Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
ff5a7b67
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ff5a7b67
编写于
7月 17, 2018
作者:
X
Xin Pan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
polish
上级
a891708d
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
85 addition
and
65 deletion
+85
-65
paddle/fluid/framework/details/broadcast_op_handle_test.cc
paddle/fluid/framework/details/broadcast_op_handle_test.cc
+10
-5
paddle/fluid/framework/details/gather_op_handle_test.cc
paddle/fluid/framework/details/gather_op_handle_test.cc
+5
-5
paddle/fluid/framework/details/multi_devices_graph_builder.cc
...le/fluid/framework/details/multi_devices_graph_builder.cc
+49
-34
paddle/fluid/framework/details/ssa_graph_builder.cc
paddle/fluid/framework/details/ssa_graph_builder.cc
+12
-9
paddle/fluid/framework/details/ssa_graph_builder.h
paddle/fluid/framework/details/ssa_graph_builder.h
+1
-1
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
...le/fluid/framework/details/threaded_ssa_graph_executor.cc
+2
-2
paddle/fluid/framework/ir/graph.cc
paddle/fluid/framework/ir/graph.cc
+1
-1
paddle/fluid/framework/ir/graph.h
paddle/fluid/framework/ir/graph.h
+2
-2
paddle/fluid/framework/ir/node.h
paddle/fluid/framework/ir/node.h
+3
-6
未找到文件。
paddle/fluid/framework/details/broadcast_op_handle_test.cc
浏览文件 @
ff5a7b67
...
...
@@ -96,7 +96,8 @@ struct TestBroadcastOpHandle {
}
param_scopes_
[
input_scope_idx
]
->
Var
(
"input"
);
std
::
unique_ptr
<
ir
::
Node
>
n
(
new
ir
::
Node
(
"node0"
));
std
::
unique_ptr
<
ir
::
Node
>
n
(
new
ir
::
Node
(
"node0"
,
ir
::
Node
::
Type
::
kOperation
));
if
(
use_gpu_
)
{
#ifdef PADDLE_WITH_CUDA
op_handle_
.
reset
(
new
BroadcastOpHandle
(
n
.
get
(),
local_scopes_
,
gpu_list_
,
...
...
@@ -114,7 +115,8 @@ struct TestBroadcastOpHandle {
#endif
}
std
::
unique_ptr
<
ir
::
Node
>
v
(
new
ir
::
Node
(
"node1"
));
std
::
unique_ptr
<
ir
::
Node
>
v
(
new
ir
::
Node
(
"node1"
,
ir
::
Node
::
Type
::
kVariable
));
auto
*
in_var_handle
=
new
VarHandle
(
v
.
get
(),
1
,
input_scope_idx
,
"input"
,
gpu_list_
[
input_scope_idx
]);
vars_
.
emplace_back
(
in_var_handle
);
...
...
@@ -122,7 +124,8 @@ struct TestBroadcastOpHandle {
// add dummy var
std
::
unique_ptr
<
ir
::
Node
>
v2
(
new
ir
::
Node
(
"node2"
));
std
::
unique_ptr
<
ir
::
Node
>
v2
(
new
ir
::
Node
(
"node2"
,
ir
::
Node
::
Type
::
kVariable
));
vars_
.
emplace_back
(
new
DummyVarHandle
(
v2
.
get
()));
DummyVarHandle
*
dummy_var_handle
=
static_cast
<
DummyVarHandle
*>
(
vars_
.
back
().
get
());
...
...
@@ -133,7 +136,8 @@ struct TestBroadcastOpHandle {
if
(
!
use_gpu_
)
{
op_handle_
->
SetDeviceContext
(
gpu_list_
[
j
],
ctxs_
[
j
].
get
());
}
std
::
unique_ptr
<
ir
::
Node
>
v3
(
new
ir
::
Node
(
"node3"
));
std
::
unique_ptr
<
ir
::
Node
>
v3
(
new
ir
::
Node
(
"node3"
,
ir
::
Node
::
Type
::
kVariable
));
VarHandle
*
out_var_handle
=
new
VarHandle
(
v3
.
get
(),
2
,
j
,
"out"
,
gpu_list_
[
j
]);
vars_
.
emplace_back
(
out_var_handle
);
...
...
@@ -141,7 +145,8 @@ struct TestBroadcastOpHandle {
}
// add dummy var
std
::
unique_ptr
<
ir
::
Node
>
v4
(
new
ir
::
Node
(
"node4"
));
std
::
unique_ptr
<
ir
::
Node
>
v4
(
new
ir
::
Node
(
"node4"
,
ir
::
Node
::
Type
::
kVariable
));
vars_
.
emplace_back
(
new
DummyVarHandle
(
v4
.
get
()));
DummyVarHandle
*
out_dummy_var_handle
=
static_cast
<
DummyVarHandle
*>
(
vars_
.
back
().
get
());
...
...
paddle/fluid/framework/details/gather_op_handle_test.cc
浏览文件 @
ff5a7b67
...
...
@@ -82,13 +82,13 @@ struct TestGatherOpHandle {
}
param_scopes_
[
input_scope_idx
]
->
Var
(
"out"
);
nodes
.
emplace_back
(
new
ir
::
Node
(
"node"
));
nodes
.
emplace_back
(
new
ir
::
Node
(
"node"
,
ir
::
Node
::
Type
::
kOperation
));
op_handle_
.
reset
(
new
GatherOpHandle
(
nodes
.
back
().
get
(),
local_scopes_
,
gpu_list_
));
// add input
for
(
size_t
j
=
0
;
j
<
gpu_list_
.
size
();
++
j
)
{
op_handle_
->
SetDeviceContext
(
gpu_list_
[
j
],
ctxs_
[
j
].
get
());
nodes
.
emplace_back
(
new
ir
::
Node
(
"node1"
));
nodes
.
emplace_back
(
new
ir
::
Node
(
"node1"
,
ir
::
Node
::
Type
::
kVariable
));
auto
*
in_var_handle
=
new
VarHandle
(
nodes
.
back
().
get
(),
1
,
j
,
"input"
,
gpu_list_
[
j
]);
vars_
.
emplace_back
(
in_var_handle
);
...
...
@@ -96,7 +96,7 @@ struct TestGatherOpHandle {
}
// add dummy var
nodes
.
emplace_back
(
new
ir
::
Node
(
"node2"
));
nodes
.
emplace_back
(
new
ir
::
Node
(
"node2"
,
ir
::
Node
::
Type
::
kVariable
));
vars_
.
emplace_back
(
new
DummyVarHandle
(
nodes
.
back
().
get
()));
DummyVarHandle
*
in_dummy_var_handle
=
static_cast
<
DummyVarHandle
*>
(
vars_
.
back
().
get
());
...
...
@@ -104,14 +104,14 @@ struct TestGatherOpHandle {
op_handle_
->
AddInput
(
in_dummy_var_handle
);
// add output
nodes
.
emplace_back
(
new
ir
::
Node
(
"node3"
));
nodes
.
emplace_back
(
new
ir
::
Node
(
"node3"
,
ir
::
Node
::
Type
::
kVariable
));
auto
*
out_var_handle
=
new
VarHandle
(
nodes
.
back
().
get
(),
2
,
input_scope_idx
,
"out"
,
gpu_list_
[
input_scope_idx
]);
vars_
.
emplace_back
(
out_var_handle
);
op_handle_
->
AddOutput
(
out_var_handle
);
// add dummy var
nodes
.
emplace_back
(
new
ir
::
Node
(
"node4"
));
nodes
.
emplace_back
(
new
ir
::
Node
(
"node4"
,
ir
::
Node
::
Type
::
kVariable
));
vars_
.
emplace_back
(
new
DummyVarHandle
(
nodes
.
back
().
get
()));
DummyVarHandle
*
dummy_var_handle
=
static_cast
<
DummyVarHandle
*>
(
vars_
.
back
().
get
());
...
...
paddle/fluid/framework/details/multi_devices_graph_builder.cc
浏览文件 @
ff5a7b67
...
...
@@ -80,7 +80,14 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(Graph *result, ir::Node *node,
}
for
(
ir
::
Node
*
output
:
node
->
outputs
)
{
CreateOpOutput
(
result
,
op_handle
,
output
,
p
,
place_id
);
ir
::
Node
*
new_node
=
nullptr
;
if
(
output
->
Var
())
{
new_node
=
result
->
CreateVarNode
(
output
->
Var
());
}
else
{
new_node
=
result
->
CreateEmptyNode
(
output
->
Name
(),
ir
::
Node
::
Type
::
kVariable
);
}
CreateOpOutput
(
result
,
op_handle
,
new_node
,
p
,
place_id
);
}
}
...
...
@@ -246,7 +253,7 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
if
(
node
->
Op
()
->
Type
()
==
"read"
&&
strategy_
.
enable_data_balance_
)
{
node
->
Op
()
->
SetAttr
(
"throw_eof_exp"
,
false
);
CreateComputationalOps
(
&
result
,
node
.
get
(),
places_
.
size
());
// TODO(pa
nyx0718
): builder shouldn't depend on the out logic of
// TODO(pa
ddle-dev
): builder shouldn't depend on the out logic of
// a specific op.
const
auto
&
data_var_names
=
node
->
Op
()
->
Output
(
"Out"
);
InsertDataBalanceOp
(
&
result
,
data_var_names
);
...
...
@@ -354,11 +361,13 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
const
std
::
string
&
p_name
,
size_t
src_dev_id
)
const
{
#ifdef PADDLE_WITH_CUDA
auto
*
op_handle
=
new
BroadcastOpHandle
(
result
->
CreateEmptyNode
(
"broadcast"
),
local_scopes_
,
places_
,
nccl_ctxs_
);
auto
*
op_handle
=
new
BroadcastOpHandle
(
result
->
CreateEmptyNode
(
"broadcast"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
,
nccl_ctxs_
);
#else
auto
*
op_handle
=
new
BroadcastOpHandle
(
result
->
CreateEmptyNode
(
"broadcast"
),
local_scopes_
,
places_
);
auto
*
op_handle
=
new
BroadcastOpHandle
(
result
->
CreateEmptyNode
(
"broadcast"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
);
#endif
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
op_handle
);
...
...
@@ -370,8 +379,9 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
auto
&
p
=
places_
[
i
];
SetCommunicationContext
(
op_handle
,
p
);
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
"vars"
).
at
(
i
).
at
(
p_name
);
auto
*
out_var
=
new
VarHandle
(
result
->
CreateEmptyNode
(
p_name
),
vars
.
size
(),
i
,
p_name
,
p
);
auto
*
out_var
=
new
VarHandle
(
result
->
CreateEmptyNode
(
p_name
,
ir
::
Node
::
Type
::
kVariable
),
vars
.
size
(),
i
,
p_name
,
p
);
vars
.
emplace_back
(
out_var
);
op_handle
->
AddOutput
(
out_var
);
}
...
...
@@ -389,12 +399,13 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
void
MultiDevSSAGraphBuilder
::
InsertAllReduceOp
(
Graph
*
result
,
const
std
::
string
&
og
)
const
{
#ifdef PADDLE_WITH_CUDA
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
AllReduceOpHandle
(
result
->
CreateEmptyNode
(
"allreduce"
),
local_scopes_
,
places_
,
nccl_ctxs_
));
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
AllReduceOpHandle
(
result
->
CreateEmptyNode
(
"allreduce"
,
ir
::
Node
::
Type
::
kOperation
)
,
local_scopes_
,
places_
,
nccl_ctxs_
));
#else
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
AllReduceOpHandle
(
result
->
CreateEmptyNode
(
"allreduce"
),
local_scopes_
,
places_
));
result
->
CreateEmptyNode
(
"allreduce"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
));
#endif
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
();
...
...
@@ -407,7 +418,8 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
op_handle
->
AddInput
(
prev_grad
.
get
());
auto
var
=
new
VarHandle
(
result
->
CreateEmptyNode
(
og
),
vars
.
size
(),
i
,
og
,
p
);
new
VarHandle
(
result
->
CreateEmptyNode
(
og
,
ir
::
Node
::
Type
::
kVariable
),
vars
.
size
(),
i
,
og
,
p
);
vars
.
emplace_back
(
var
);
op_handle
->
AddOutput
(
var
);
}
...
...
@@ -416,12 +428,13 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
void
MultiDevSSAGraphBuilder
::
InsertDataBalanceOp
(
Graph
*
result
,
const
std
::
vector
<
std
::
string
>
&
datas
)
const
{
#ifdef PADDLE_WITH_CUDA
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
DataBalanceOpHandle
(
result
->
CreateEmptyNode
(
"data_balance"
),
local_scopes_
,
places_
,
nccl_ctxs_
));
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
DataBalanceOpHandle
(
result
->
CreateEmptyNode
(
"data_balance"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
,
nccl_ctxs_
));
#else
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
DataBalanceOpHandle
(
result
->
CreateEmptyNode
(
"data_balance"
),
local_scopes_
,
places_
));
result
->
CreateEmptyNode
(
"data_balance"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
));
#endif
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
();
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
...
...
@@ -431,8 +444,9 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
"vars"
)[
i
][
d_name
];
PADDLE_ENFORCE
(
!
vars
.
empty
());
op_handle
->
AddInput
(
vars
.
back
().
get
());
auto
var
=
new
VarHandle
(
result
->
CreateEmptyNode
(
d_name
),
vars
.
size
(),
i
,
d_name
,
p
);
auto
var
=
new
VarHandle
(
result
->
CreateEmptyNode
(
d_name
,
ir
::
Node
::
Type
::
kVariable
),
vars
.
size
(),
i
,
d_name
,
p
);
vars
.
emplace_back
(
var
);
op_handle
->
AddOutput
(
var
);
}
...
...
@@ -487,8 +501,9 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
platform
::
DeviceContextPool
::
Instance
().
Get
(
platform
::
CPUPlace
());
#endif
auto
*
op_handle
=
new
ScaleLossGradOpHandle
(
result
->
CreateEmptyNode
(
"scale_loss_grad"
),
local_scopes_
.
size
(),
local_scopes_
[
i
],
places_
[
i
],
communication_dev_ctx
);
result
->
CreateEmptyNode
(
"scale_loss_grad"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
.
size
(),
local_scopes_
[
i
],
places_
[
i
],
communication_dev_ctx
);
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
op_handle
);
// FIXME: Currently ScaleLossGradOp only use device_count as scale
...
...
@@ -497,14 +512,10 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
// loss->pending_ops_.emplace_back(op_handle);
// op_handle->inputs_.emplace_back(loss);
// TODO(panyx0718): GradVarName(loss_var_name_)
const
std
::
string
grad_var_name
=
GradVarName
(
loss_var_name_
);
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
"vars"
)[
i
][
grad_var_name
];
size_t
version
=
vars
.
size
();
auto
var
=
new
VarHandle
(
result
->
CreateEmptyNode
(
grad_var_name
),
version
,
i
,
grad_var_name
,
places_
[
i
]);
vars
.
emplace_back
(
var
);
op_handle
->
AddOutput
(
var
);
CreateOpOutput
(
result
,
op_handle
,
result
->
CreateEmptyNode
(
GradVarName
(
loss_var_name_
),
ir
::
Node
::
Type
::
kVariable
),
places_
[
i
],
i
);
}
}
...
...
@@ -525,10 +536,12 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
int
dst_dev_id
)
const
{
#ifdef PADDLE_WITH_CUDA
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
ReduceOpHandle
(
result
->
CreateEmptyNode
(
"reduce"
),
local_scopes_
,
places_
,
nccl_ctxs_
));
result
->
CreateEmptyNode
(
"reduce"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
,
nccl_ctxs_
));
#else
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
ReduceOpHandle
(
result
->
CreateEmptyNode
(
"reduce"
),
local_scopes_
,
places_
));
result
->
CreateEmptyNode
(
"reduce"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
));
#endif
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
();
...
...
@@ -541,8 +554,9 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
op_handle
->
AddInput
(
prev_grad
.
get
());
}
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
"vars"
)[
dst_dev_id
][
og
];
auto
var
=
new
VarHandle
(
result
->
CreateEmptyNode
(
og
),
vars
.
size
(),
dst_dev_id
,
og
,
places_
[
dst_dev_id
]);
auto
var
=
new
VarHandle
(
result
->
CreateEmptyNode
(
og
,
ir
::
Node
::
Type
::
kVariable
),
vars
.
size
(),
dst_dev_id
,
og
,
places_
[
dst_dev_id
]);
vars
.
emplace_back
(
var
);
op_handle
->
AddOutput
(
var
);
return
var
;
...
...
@@ -554,7 +568,8 @@ void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op,
const
std
::
string
&
prev_op_name
)
const
{
for
(
auto
&
prev_op
:
result
->
Get
<
GraphOps
>
(
"ops"
))
{
if
(
prev_op
->
Name
()
==
prev_op_name
)
{
auto
*
dep_var
=
new
DummyVarHandle
(
result
->
CreateEmptyNode
(
"dummy"
));
auto
*
dep_var
=
new
DummyVarHandle
(
result
->
CreateEmptyNode
(
"dummy"
,
ir
::
Node
::
Type
::
kVariable
));
prev_op
->
AddOutput
(
dep_var
);
result
->
Get
<
GraphDepVars
>
(
"dep_vars"
).
emplace
(
dep_var
);
op
->
AddInput
(
dep_var
);
...
...
paddle/fluid/framework/details/ssa_graph_builder.cc
浏览文件 @
ff5a7b67
...
...
@@ -37,7 +37,8 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
continue
;
}
auto
*
dep_var
=
new
DummyVarHandle
(
graph
->
CreateEmptyNode
(
"dummy"
));
auto
*
dep_var
=
new
DummyVarHandle
(
graph
->
CreateEmptyNode
(
"dummy"
,
ir
::
Node
::
Type
::
kVariable
));
read_op
->
AddOutput
(
dep_var
);
write_op
->
AddInput
(
dep_var
);
graph
->
Get
<
GraphDepVars
>
(
"dep_vars"
).
emplace
(
dep_var
);
...
...
@@ -54,12 +55,13 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
auto
&
var_holder
=
var_holders
[
node
->
Name
()];
VarHandle
*
var
=
nullptr
;
if
(
var_holder
.
empty
())
{
if
(
node
->
NodeType
()
==
ir
::
Node
::
Type
::
kVariable
)
{
if
(
node
->
Var
()
)
{
var
=
new
VarHandle
(
graph
->
CreateVarNode
(
node
->
Var
()),
0
,
place_offset
,
node
->
Name
(),
place
);
}
else
{
var
=
new
VarHandle
(
graph
->
CreateEmptyNode
(
node
->
Name
()),
0
,
place_offset
,
node
->
Name
(),
place
);
var
=
new
VarHandle
(
graph
->
CreateEmptyNode
(
node
->
Name
(),
ir
::
Node
::
Type
::
kVariable
),
0
,
place_offset
,
node
->
Name
(),
place
);
}
var_holder
.
emplace_back
(
var
);
}
else
{
...
...
@@ -69,13 +71,13 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
}
void
SSAGraphBuilder
::
CreateOpOutput
(
Graph
*
graph
,
OpHandleBase
*
op_handle
,
ir
::
Node
*
node
,
ir
::
Node
*
n
ew_n
ode
,
const
platform
::
Place
&
place
,
size_t
place_offset
)
{
auto
&
vars
=
graph
->
Get
<
GraphVars
>
(
"vars"
)[
place_offset
][
node
->
Name
()];
auto
&
vars
=
graph
->
Get
<
GraphVars
>
(
"vars"
)[
place_offset
][
n
ew_n
ode
->
Name
()];
size_t
version
=
vars
.
size
();
auto
var
=
new
VarHandle
(
graph
->
CreateVarNode
(
node
->
Var
()),
version
,
place_offset
,
node
->
Name
(),
place
);
auto
var
=
new
VarHandle
(
new_node
,
version
,
place_offset
,
new_
node
->
Name
(),
place
);
vars
.
emplace_back
(
var
);
op_handle
->
AddOutput
(
var
);
}
...
...
@@ -85,7 +87,8 @@ void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) {
if
(
!
op
->
Outputs
().
empty
())
{
continue
;
}
auto
*
dummy_leaf
=
new
DummyVarHandle
(
graph
->
CreateEmptyNode
(
"dummy"
));
auto
*
dummy_leaf
=
new
DummyVarHandle
(
graph
->
CreateEmptyNode
(
"dummy"
,
ir
::
Node
::
Type
::
kVariable
));
graph
->
Get
<
GraphDepVars
>
(
"dep_vars"
).
emplace
(
dummy_leaf
);
op
->
AddOutput
(
dummy_leaf
);
}
...
...
paddle/fluid/framework/details/ssa_graph_builder.h
浏览文件 @
ff5a7b67
...
...
@@ -73,7 +73,7 @@ class SSAGraphBuilder : public ir::Pass {
// Add an output variable (each_var_name, place, place_offset) to op_handle,
// which belongs to graph
static
void
CreateOpOutput
(
Graph
*
graph
,
OpHandleBase
*
op_handle
,
ir
::
Node
*
node
,
const
platform
::
Place
&
place
,
ir
::
Node
*
n
ew_n
ode
,
const
platform
::
Place
&
place
,
size_t
place_offset
);
static
void
AddOutputToLeafOps
(
Graph
*
graph
);
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
浏览文件 @
ff5a7b67
...
...
@@ -173,7 +173,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
auto
&
var_name
=
fetch_tensors
[
i
];
auto
&
vars
=
fetched_vars
.
at
(
var_name
);
temp_nodes
->
emplace_back
(
new
ir
::
Node
(
"fetch"
));
temp_nodes
->
emplace_back
(
new
ir
::
Node
(
"fetch"
,
ir
::
Node
::
Type
::
kOperation
));
auto
*
op
=
new
FetchOpHandle
(
temp_nodes
->
back
().
get
(),
fetch_data
,
i
,
&
local_scopes_
);
fetch_ops
->
emplace_back
(
op
);
...
...
@@ -186,7 +186,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
op
->
AddInput
(
var
);
}
temp_nodes
->
emplace_back
(
new
ir
::
Node
(
"fetch"
));
temp_nodes
->
emplace_back
(
new
ir
::
Node
(
"fetch"
,
ir
::
Node
::
Type
::
kOperation
));
auto
*
fetch_dummy
=
new
DummyVarHandle
(
temp_nodes
->
back
().
get
());
op
->
AddOutput
(
fetch_dummy
);
fetch_dependencies
->
emplace
(
fetch_dummy
);
...
...
paddle/fluid/framework/ir/graph.cc
浏览文件 @
ff5a7b67
...
...
@@ -41,7 +41,7 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
// TODO(paddle-dev): Seems some assumption doesn't hold?
LOG
(
ERROR
)
<<
op
->
Type
()
<<
" input var not in all_var list: "
<<
each_var_name
;
var
=
CreateEmptyNode
(
each_var_name
);
var
=
CreateEmptyNode
(
each_var_name
,
ir
::
Node
::
Type
::
kVariable
);
var_nodes
[
each_var_name
]
=
var
;
}
node
->
inputs
.
push_back
(
var
);
...
...
paddle/fluid/framework/ir/graph.h
浏览文件 @
ff5a7b67
...
...
@@ -67,8 +67,8 @@ class Graph {
// TODO(paddle-dev): There shouldn't be kNone nodes in the ir::Graph.
// node should either be a executable kOperation or a kVariable. kNone
// node is a temporary solution.
ir
::
Node
*
CreateEmptyNode
(
const
std
::
string
&
name
)
{
nodes
.
emplace_back
(
new
ir
::
Node
(
name
));
ir
::
Node
*
CreateEmptyNode
(
const
std
::
string
&
name
,
ir
::
Node
::
Type
type
)
{
nodes
.
emplace_back
(
new
ir
::
Node
(
name
,
type
));
return
nodes
.
back
().
get
();
}
...
...
paddle/fluid/framework/ir/node.h
浏览文件 @
ff5a7b67
...
...
@@ -26,12 +26,9 @@ namespace ir {
class
Node
{
public:
enum
class
Type
{
kNone
,
kOperation
,
kVariable
};
explicit
Node
(
const
std
::
string
&
name
)
:
name_
(
name
),
var_desc_
(
nullptr
),
op_desc_
(
nullptr
),
type_
(
Type
::
kNone
)
{}
enum
class
Type
{
kOperation
,
kVariable
};
explicit
Node
(
const
std
::
string
&
name
,
Type
type
)
:
name_
(
name
),
var_desc_
(
nullptr
),
op_desc_
(
nullptr
),
type_
(
type
)
{}
explicit
Node
(
VarDesc
*
var_desc
)
:
name_
(
var_desc
->
Name
()),
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录