Crayon鑫 / Paddle (forked from PaddlePaddle / Paddle)

Commit 2fa8df1c
Authored Jul 13, 2018 by Xin Pan

separate graph building pass and graph-based pe builder

Parent 37e51443

Showing 15 changed files with 273 additions and 182 deletions (+273 −182)

paddle/fluid/framework/details/broadcast_op_handle_test.cc     +5   -5
paddle/fluid/framework/details/gather_op_handle_test.cc        +5   -5
paddle/fluid/framework/details/multi_devices_graph_builder.cc  +147 -113
paddle/fluid/framework/details/multi_devices_graph_builder.h   +11  -14
paddle/fluid/framework/details/reduce_op_handle_test.cc        +3   -3
paddle/fluid/framework/details/ssa_graph_builder.cc            +11  -15
paddle/fluid/framework/details/ssa_graph_builder.h             +6   -6
paddle/fluid/framework/details/ssa_graph_checker.h             +3   -3
paddle/fluid/framework/details/ssa_graph_printer.h             +3   -3
paddle/fluid/framework/ir/graph.cc                             +33  -0
paddle/fluid/framework/ir/graph.h                              +19  -2
paddle/fluid/framework/ir/node.h                               +20  -5
paddle/fluid/framework/ir/pass.h                               +4   -4
paddle/fluid/framework/parallel_executor.cc                    +1   -4
python/paddle/fluid/parallel_executor.py                       +2   -0

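Taken together, the changes recast graph building as a compiler-style pass: SSAGraphBuilder now derives from ir::Pass, its Build() entry point becomes the Pass interface's Apply(), and ParallelExecutor first converts the ProgramDesc into an ir::Graph with the new ProgramToGraph() before handing it to the builder. Below is a minimal self-contained sketch of that shape; the Graph, ProgramDesc, and pass types here are mocks for illustration, not Paddle's real classes.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Mock stand-ins, just to show the data flow of the new design.
struct ProgramDesc { std::vector<std::string> ops; };
struct Graph { std::vector<std::string> nodes; };

// The Pass interface after this commit: a pure-virtual, const Apply().
class Pass {
 public:
  virtual ~Pass() = default;
  virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const = 0;
};

// Program-to-graph conversion is now a separate, reusable step.
std::unique_ptr<Graph> ProgramToGraph(const ProgramDesc &program) {
  auto graph = std::make_unique<Graph>();
  for (const auto &op : program.ops) graph->nodes.push_back(op);
  return graph;
}

// A graph builder becomes just one pass among others.
class MultiDevBuilderPass : public Pass {
 public:
  std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override {
    graph->nodes.push_back("scale_loss_grad");  // rewrite the graph
    return graph;
  }
};

int main() {
  ProgramDesc program{{"mul", "elementwise_add"}};
  std::unique_ptr<Pass> builder(new MultiDevBuilderPass);
  // Mirrors the new call site in parallel_executor.cc:
  //   builder_->Apply(ProgramToGraph(main_program))
  std::unique_ptr<Graph> graph = builder->Apply(ProgramToGraph(program));
  for (const auto &n : graph->nodes) std::cout << n << "\n";
}
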
paddle/fluid/framework/details/broadcast_op_handle_test.cc

@@ -96,7 +96,7 @@ struct TestBroadcastOpHandle {
     }
     param_scopes_[input_scope_idx]->Var("input");
 
-    std::unique_ptr<ir::Node> n(new ir::Node(ir::Node::Type::kOperation));
+    std::unique_ptr<ir::Node> n(new ir::Node());
     if (use_gpu_) {
 #ifdef PADDLE_WITH_CUDA
       op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_,
@@ -114,7 +114,7 @@ struct TestBroadcastOpHandle {
 #endif
     }
 
-    std::unique_ptr<ir::Node> v(new ir::Node(ir::Node::Type::kVariable));
+    std::unique_ptr<ir::Node> v(new ir::Node());
     auto *in_var_handle = new VarHandle(v.get(), 1, input_scope_idx, "input",
                                         gpu_list_[input_scope_idx]);
     vars_.emplace_back(in_var_handle);
@@ -122,7 +122,7 @@ struct TestBroadcastOpHandle {
     // add dummy var
 
-    std::unique_ptr<ir::Node> v2(new ir::Node(ir::Node::Type::kVariable));
+    std::unique_ptr<ir::Node> v2(new ir::Node());
     vars_.emplace_back(new DummyVarHandle(v2.get()));
     DummyVarHandle *dummy_var_handle =
         static_cast<DummyVarHandle *>(vars_.back().get());
@@ -133,7 +133,7 @@ struct TestBroadcastOpHandle {
       if (!use_gpu_) {
         op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
       }
-      std::unique_ptr<ir::Node> v3(new ir::Node(ir::Node::Type::kVariable));
+      std::unique_ptr<ir::Node> v3(new ir::Node());
       VarHandle *out_var_handle =
           new VarHandle(v3.get(), 2, j, "out", gpu_list_[j]);
       vars_.emplace_back(out_var_handle);
@@ -141,7 +141,7 @@ struct TestBroadcastOpHandle {
     }
 
     // add dummy var
-    std::unique_ptr<ir::Node> v4(new ir::Node(ir::Node::Type::kVariable));
+    std::unique_ptr<ir::Node> v4(new ir::Node());
     vars_.emplace_back(new DummyVarHandle(v4.get()));
     DummyVarHandle *out_dummy_var_handle =
         static_cast<DummyVarHandle *>(vars_.back().get());

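The test-file changes (this file plus the gather and reduce tests below) all follow one ownership idiom: the fixture owns each ir::Node through a std::unique_ptr (the locals n, v, v2, ... here; a nodes vector in the other two tests), while op and var handles store only raw, non-owning pointers obtained with .get(). A small sketch of that split, using hypothetical Node/Handle stand-ins rather than the real classes:

#include <cassert>
#include <memory>
#include <vector>

// Hypothetical minimal stand-ins for ir::Node and VarHandle.
struct Node { int id; };
struct Handle {
  explicit Handle(Node *n) : node(n) {}  // borrows, never deletes
  Node *node;
};

int main() {
  // The fixture owns the nodes, exactly like
  // nodes.emplace_back(new ir::Node()) in the tests ...
  std::vector<std::unique_ptr<Node>> nodes;
  nodes.emplace_back(new Node{1});

  // ... while handles only borrow them via .get().
  std::vector<Handle> handles;
  handles.emplace_back(nodes.back().get());

  // Valid for as long as `nodes` outlives `handles`.
  assert(handles.back().node->id == 1);
  return 0;
}
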
paddle/fluid/framework/details/gather_op_handle_test.cc

@@ -82,13 +82,13 @@ struct TestGatherOpHandle {
     }
     param_scopes_[input_scope_idx]->Var("out");
 
-    nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
+    nodes.emplace_back(new ir::Node());
     op_handle_.reset(
         new GatherOpHandle(nodes.back().get(), local_scopes_, gpu_list_));
     // add input
     for (size_t j = 0; j < gpu_list_.size(); ++j) {
       op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
-      nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
+      nodes.emplace_back(new ir::Node());
       auto *in_var_handle =
           new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]);
       vars_.emplace_back(in_var_handle);
@@ -96,7 +96,7 @@ struct TestGatherOpHandle {
     }
 
     // add dummy var
-    nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
+    nodes.emplace_back(new ir::Node());
     vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
     DummyVarHandle *in_dummy_var_handle =
         static_cast<DummyVarHandle *>(vars_.back().get());
@@ -104,14 +104,14 @@ struct TestGatherOpHandle {
     op_handle_->AddInput(in_dummy_var_handle);
 
     // add output
-    nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
+    nodes.emplace_back(new ir::Node());
     auto *out_var_handle =
         new VarHandle(nodes.back().get(), 2, input_scope_idx, "out",
                       gpu_list_[input_scope_idx]);
     vars_.emplace_back(out_var_handle);
     op_handle_->AddOutput(out_var_handle);
 
     // add dummy var
-    nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
+    nodes.emplace_back(new ir::Node());
     vars_.emplace_back(new DummyVarHandle(nodes.back().get()));
     DummyVarHandle *dummy_var_handle =
         static_cast<DummyVarHandle *>(vars_.back().get());

paddle/fluid/framework/details/multi_devices_graph_builder.cc

(This diff is collapsed in the page capture; +147 -113, not shown here.)

paddle/fluid/framework/details/multi_devices_graph_builder.h

@@ -46,13 +46,11 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
                           const std::vector<Scope *> &local_scopes,
                           const BuildStrategy &strategy);
 #endif
 
-  std::unique_ptr<Graph> Build(std::unique_ptr<Graph> graph) const override;
+  std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override;
   int GetVarDeviceID(const std::string &varname) const override;
 
  private:
-  void CreateOpHandleIOs(Graph *result, const OpDesc &op,
-                         size_t device_id) const;
+  void CreateOpHandleIOs(Graph *result, ir::Node *node, size_t device_id) const;
 
  private:
   std::string loss_var_name_;
@@ -64,40 +62,39 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   platform::NCCLContextMap *nccl_ctxs_;
 #endif
 
-  bool IsScaleLossOp(const OpDesc &op) const;
+  bool IsScaleLossOp(ir::Node *node) const;
 
-  void CreateRPCOp(Graph *result, const OpDesc &op) const;
-  void CreateDistTrainOp(Graph *result, const OpDesc &op) const;
+  void CreateRPCOp(Graph *result, ir::Node *node) const;
+  void CreateDistTrainOp(Graph *result, ir::Node *node) const;
 
   /**
    * Is this operator as the end-point operator before/after send operator.
    */
-  bool IsDistTrainOp(const OpDesc &op,
-                     const std::vector<std::string> &send_vars,
+  bool IsDistTrainOp(ir::Node *node, const std::vector<std::string> &send_vars,
                      const std::vector<std::string> &recv_vars) const;
 
   std::vector<std::string> FindDistTrainSendVars(
-      const ProgramDesc &program) const;
+      const std::vector<std::unique_ptr<ir::Node>> &nodes) const;
 
   std::vector<std::string> FindDistTrainRecvVars(
-      const ProgramDesc &program) const;
+      const std::vector<std::unique_ptr<ir::Node>> &nodes) const;
 
   void ConnectOp(Graph *result, OpHandleBase *op,
                  const std::string &prev_op_name) const;
 
-  void CreateComputationalOps(Graph *result, const OpDesc &op,
+  void CreateComputationalOps(Graph *result, ir::Node *node,
                               size_t num_places) const;
 
   void CreateScaleLossGradOp(Graph *result) const;
   VarHandle *CreateReduceOp(Graph *result, const std::string &og,
                             int dst_dev_id) const;
 
-  void CreateComputationalOp(Graph *result, const OpDesc &op, int dev_id) const;
+  void CreateComputationalOp(Graph *result, ir::Node *node, int dev_id) const;
 
   bool IsParameterGradientOnce(
       const std::string &og,
       std::unordered_set<std::string> *og_has_been_broadcast) const;
 
-  int GetOpDeviceID(const OpDesc &op) const;
+  int GetOpDeviceID(ir::Node *node) const;
 
   void InsertAllReduceOp(Graph *result, const std::string &og) const;

paddle/fluid/framework/details/reduce_op_handle_test.cc

@@ -97,7 +97,7 @@ struct TestReduceOpHandle {
     }
     param_scopes_[out_scope_idx]->Var("out");
 
-    nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation));
+    nodes.emplace_back(new ir::Node());
     if (use_gpu_) {
 #ifdef PADDLE_WITH_CUDA
       op_handle_.reset(new ReduceOpHandle(nodes.back().get(), local_scopes_,
@@ -121,7 +121,7 @@ struct TestReduceOpHandle {
       if (!use_gpu_) {
         op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
       }
-      nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
+      nodes.emplace_back(new ir::Node());
       auto *in_var_handle =
           new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]);
       in_var_handle->ClearGeneratedOp();
@@ -137,7 +137,7 @@ struct TestReduceOpHandle {
     op_handle_->AddInput(in_dummy_var_handle);
 
     // add output
-    nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
+    nodes.emplace_back(new ir::Node());
     auto *out_var_handle = new VarHandle(nodes.back().get(), 2, out_scope_idx,
                                          "out", gpu_list_[out_scope_idx]);
     vars_.emplace_back(out_var_handle);

paddle/fluid/framework/details/ssa_graph_builder.cc

@@ -37,8 +37,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
           continue;
         }
 
-        graph->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
-        auto *dep_var = new DummyVarHandle(graph->nodes.back().get());
+        auto *dep_var = new DummyVarHandle(graph->CreateVarNode("dummy"));
         read_op->AddOutput(dep_var);
         write_op->AddInput(dep_var);
         graph->Get<GraphDepVars>("dep_vars").emplace(dep_var);
@@ -49,15 +48,14 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
 }
 
 VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
-    Graph *graph, const std::string &each_var_name,
-    const platform::Place &place, size_t place_offset) {
+    Graph *graph, ir::Node *node, const platform::Place &place,
+    size_t place_offset) {
   auto &var_holders = graph->Get<GraphVars>("vars")[place_offset];
-  auto &var_holder = var_holders[each_var_name];
+  auto &var_holder = var_holders[node->Var()->Name()];
   VarHandle *var = nullptr;
   if (var_holder.empty()) {
-    graph->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
-    var = new VarHandle(graph->nodes.back().get(), 0, place_offset,
-                        each_var_name, place);
+    var = new VarHandle(graph->CreateVarNode(node->Var()), 0, place_offset,
+                        node->Var()->Name(), place);
     var_holder.emplace_back(var);
   } else {
     var = var_holder.rbegin()->get();
@@ -66,14 +64,13 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
 }
 
 void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
-                                     const std::string &each_var_name,
+                                     ir::Node *node,
                                      const platform::Place &place,
                                      size_t place_offset) {
-  auto &vars = graph->Get<GraphVars>("vars")[place_offset][each_var_name];
+  auto &vars =
+      graph->Get<GraphVars>("vars")[place_offset][node->Var()->Name()];
   size_t version = vars.size();
-  graph->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
-  auto var = new VarHandle(graph->nodes.back().get(), version, place_offset,
-                           each_var_name, place);
+  auto var = new VarHandle(graph->CreateVarNode(node->Var()), version,
+                           place_offset, node->Var()->Name(), place);
   vars.emplace_back(var);
   op_handle->AddOutput(var);
 }
@@ -85,8 +82,7 @@ void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) {
     if (!op->Outputs().empty()) {
      continue;
     }
-    graph->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable));
-    auto *dummy_leaf = new DummyVarHandle(graph->nodes.back().get());
+    auto *dummy_leaf = new DummyVarHandle(graph->CreateVarNode("dummy"));
     graph->Get<GraphDepVars>("dep_vars").emplace(dummy_leaf);
     op->AddOutput(dummy_leaf);
   }

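The two helpers above carry the SSA bookkeeping: per place, each variable name maps to a stack of VarHandles; CreateOrGetLatestVarHandle returns the newest version (creating version 0 on first use), and CreateOpOutput appends a fresh handle whose version equals the current stack size. A simplified, self-contained sketch of that versioning rule (single place, hypothetical types, not Paddle's real signatures):

#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// One VarHandle per (name, version); simplified to a single place.
struct VarHandle {
  std::string name;
  size_t version;
};

using VarMap = std::unordered_map<std::string, std::vector<VarHandle>>;

// Analogue of CreateOrGetLatestVarHandle: reads see the newest version,
// creating version 0 on first touch.
VarHandle &GetLatest(VarMap &vars, const std::string &name) {
  auto &holder = vars[name];
  if (holder.empty()) holder.push_back({name, 0});
  return holder.back();
}

// Analogue of CreateOpOutput: every write appends a fresh version.
VarHandle &CreateOutput(VarMap &vars, const std::string &name) {
  auto &holder = vars[name];
  holder.push_back({name, holder.size()});
  return holder.back();
}

int main() {
  VarMap vars;
  GetLatest(vars, "w");     // w@0 created on first read
  CreateOutput(vars, "w");  // a write produces w@1
  std::cout << GetLatest(vars, "w").version << "\n";  // prints 1
}
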
paddle/fluid/framework/details/ssa_graph_builder.h

@@ -23,6 +23,7 @@
 #include "paddle/fluid/platform/place.h"
 
 #include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/pass.h"
 
 namespace paddle {
 namespace framework {
@@ -34,11 +35,11 @@ typedef std::vector<
 typedef std::unordered_set<std::unique_ptr<VarHandleBase>> GraphDepVars;
 typedef std::vector<std::unique_ptr<OpHandleBase>> GraphOps;
 
-class SSAGraphBuilder {
+class SSAGraphBuilder : public ir::Pass {
  public:
   SSAGraphBuilder() {}
   virtual ~SSAGraphBuilder() {}
 
-  virtual std::unique_ptr<Graph> Build(std::unique_ptr<Graph> graph) const = 0;
   virtual int GetVarDeviceID(const std::string &var_name) const = 0;
 
   DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
@@ -53,16 +54,15 @@ class SSAGraphBuilder {
    */
   static void PolishGraphToSupportDataHazards(Graph *graph);
 
-  static VarHandle *CreateOrGetLatestVarHandle(
-      Graph *graph, const std::string &each_var_name,
+  static VarHandle *CreateOrGetLatestVarHandle(Graph *graph, ir::Node *node,
                                                const platform::Place &place,
                                                size_t place_offset);
 
   // Add an output variable (each_var_name, place, place_offset) to op_handle,
   // which belongs to graph
   static void CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
-                             const std::string &each_var_name,
-                             const platform::Place &place,
+                             ir::Node *node, const platform::Place &place,
                              size_t place_offset);
 
   static void AddOutputToLeafOps(Graph *graph);
 };

paddle/fluid/framework/details/ssa_graph_checker.h

@@ -28,10 +28,10 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
       std::unique_ptr<SSAGraphBuilder>&& builder)
       : builder_(std::move(builder)) {}
 
-  std::unique_ptr<Graph> Build(std::unique_ptr<Graph> graph) const override {
-    auto new_graph = builder_->Build(std::move(graph));
+  std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override {
+    auto new_graph = builder_->Apply(std::move(graph));
     PADDLE_ENFORCE(IsValidGraph(new_graph.get()));
-    return new_graph;
+    return std::move(new_graph);
   }
 
   int GetVarDeviceID(const std::string& var_name) const override {

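SSAGraghBuilderWithChecker (and SSAGraghBuilderWithPrinter below) is a decorator: it owns an inner builder, forwards Apply(), and validates the result before handing it back, so checking can be layered onto any builder. A minimal sketch of the pattern with mock types, not the real Paddle classes:

#include <cassert>
#include <memory>
#include <utility>

struct Graph { bool valid = true; };

class Builder {
 public:
  virtual ~Builder() = default;
  virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> g) const = 0;
};

class RealBuilder : public Builder {
 public:
  std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> g) const override {
    return g;  // would rewrite the graph here
  }
};

// Decorator: wraps any Builder and checks its output, mirroring
// SSAGraghBuilderWithChecker::Apply.
class CheckedBuilder : public Builder {
 public:
  explicit CheckedBuilder(std::unique_ptr<Builder> inner)
      : inner_(std::move(inner)) {}
  std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> g) const override {
    auto out = inner_->Apply(std::move(g));
    assert(out->valid);  // stands in for PADDLE_ENFORCE(IsValidGraph(...))
    return out;
  }

 private:
  std::unique_ptr<Builder> inner_;
};

int main() {
  CheckedBuilder b(std::make_unique<RealBuilder>());
  auto g = b.Apply(std::make_unique<Graph>());
}

Note that the accompanying change from `return new_graph;` to `return std::move(new_graph);` is redundant on a conforming compiler, since a named local of move-only type is implicitly moved on return; it was presumably a workaround for an older toolchain.
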
paddle/fluid/framework/details/ssa_graph_printer.h

@@ -50,10 +50,10 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
         stream_ptr_(std::move(sout)),
         stream_ref_(*stream_ptr_) {}
 
-  std::unique_ptr<Graph> Build(std::unique_ptr<Graph> graph) const override {
-    auto new_graph = builder_->Build(std::move(graph));
+  std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override {
+    auto new_graph = builder_->Apply(std::move(graph));
     printer_->Print(*new_graph, stream_ref_);
-    return new_graph;
+    return std::move(new_graph);
   }
 
   int GetVarDeviceID(const std::string& var_name) const override {

paddle/fluid/framework/ir/graph.cc

@@ -13,12 +13,45 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/framework/var_desc.h"
 
 namespace paddle {
 namespace framework {
 
+std::unique_ptr<Graph> ProgramToGraph(const ProgramDesc &program) {
+  std::unique_ptr<Graph> graph(new Graph(program));
+
+  std::unordered_map<std::string, VarDesc *> all_vars;
+  for (auto *var : program.Block(0).AllVars()) {
+    all_vars.emplace(var->Name(), var);
+  }
+
+  for (auto *op : program.Block(0).AllOps()) {
+    ir::Node *node = graph->CreateOpNode(op);
+
+    for (auto &each_var_name : op->InputArgumentNames()) {
+      ir::Node *var = nullptr;
+      if (all_vars.count(each_var_name) != 0) {
+        var = graph->CreateVarNode(all_vars.at(each_var_name));
+      } else {
+        var = graph->CreateVarNode(each_var_name);
+      }
+      node->inputs.push_back(var);
+      var->outputs.push_back(node);
+    }
+
+    for (auto &each_var_name : op->OutputArgumentNames()) {
+      ir::Node *var = nullptr;
+      if (all_vars.count(each_var_name) != 0) {
+        var = graph->CreateVarNode(all_vars.at(each_var_name));
+      } else {
+        var = graph->CreateVarNode(each_var_name);
+      }
+      node->outputs.push_back(var);
+      var->inputs.push_back(node);
+    }
+  }
+  return std::move(graph);
+}

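ProgramToGraph builds a bipartite graph over Block(0): one op node per OpDesc, var nodes for its arguments, with forward and backward edges filled in pairs. Note that CreateVarNode runs once per argument occurrence, so as written two ops naming the same variable get distinct var nodes that merely share a VarDesc, presumably left for later passes to reconcile. A mock sketch of the wiring (hypothetical Node type, not ir::Node):

#include <cassert>
#include <memory>
#include <string>
#include <vector>

// Mock mirroring the bipartite wiring done in ProgramToGraph: an op node
// and its argument var nodes point at each other.
struct Node {
  std::string name;
  std::vector<Node *> inputs;
  std::vector<Node *> outputs;
};

int main() {
  std::vector<std::unique_ptr<Node>> nodes;  // the graph owns every node

  nodes.emplace_back(new Node{"mul", {}, {}});
  Node *op = nodes.back().get();

  // CreateVarNode runs per argument occurrence, so a second op reading
  // "w" would get its own, distinct var node sharing the VarDesc.
  nodes.emplace_back(new Node{"w", {}, {}});
  Node *var = nodes.back().get();

  // The double link built for each input argument:
  op->inputs.push_back(var);
  var->outputs.push_back(op);

  assert(op->inputs.front() == var && var->outputs.front() == op);
}
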
paddle/fluid/framework/ir/graph.h

@@ -39,8 +39,6 @@ class Graph {
     attr_dels_.clear();
   }
 
-  const ProgramDesc &Program() const { return program_; }
-
   template <typename AttrType>
   AttrType &Get(const std::string &attr_name) const {
     return *boost::any_cast<AttrType *>(attrs_.at(attr_name));
@@ -63,11 +61,30 @@ class Graph {
     return attr;
   }
 
+  ir::Node *CreateVarNode(VarDesc *var_desc) {
+    nodes.emplace_back(new ir::Node(var_desc));
+    return nodes.back().get();
+  }
+
+  ir::Node *CreateOpNode(OpDesc *op_desc) {
+    nodes.emplace_back(new ir::Node(op_desc));
+    return nodes.back().get();
+  }
+
+  // TODO(panyx0718): Need to handle CreateOpNode(nullptr).
+  ir::Node *CreateVarNode(const std::string &var_name) {
+    var_descs_.emplace_back(new VarDesc(var_name));
+    nodes.emplace_back(new ir::Node(var_descs_.back().get()));
+    return nodes.back().get();
+  }
+
   std::vector<ir::Node *> inputs;
   std::vector<ir::Node *> outputs;
   std::vector<std::unique_ptr<ir::Node>> nodes;
+  std::vector<std::unique_ptr<VarDesc>> var_descs_;
 
  private:
+  // NOTE: program_ shouldn't be exposed to user.
   const ProgramDesc &program_;
   std::map<std::string, boost::any> attrs_;
   std::map<std::string, std::function<void(void)>> attr_dels_;

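The new Create* methods make Graph the single owner of every node: nodes live in a vector of unique_ptrs, callers receive raw non-owning pointers, and the string overload additionally materializes a placeholder VarDesc that the graph keeps alive in var_descs_. A self-contained sketch of this ownership model (VarDesc and Node here are mocks):

#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct VarDesc {
  explicit VarDesc(std::string n) : name(std::move(n)) {}
  std::string name;
};

struct Node {
  explicit Node(VarDesc *v) : var(v) {}
  VarDesc *var;  // non-owning
};

class Graph {
 public:
  // Mirrors ir::Graph::CreateVarNode(VarDesc*): graph owns, caller borrows.
  Node *CreateVarNode(VarDesc *var_desc) {
    nodes_.emplace_back(new Node(var_desc));
    return nodes_.back().get();
  }

  // Mirrors the string overload: a placeholder VarDesc is created for
  // names without a descriptor, and the graph keeps it alive too.
  Node *CreateVarNode(const std::string &name) {
    var_descs_.emplace_back(new VarDesc(name));
    return CreateVarNode(var_descs_.back().get());
  }

 private:
  std::vector<std::unique_ptr<Node>> nodes_;
  std::vector<std::unique_ptr<VarDesc>> var_descs_;
};

int main() {
  Graph g;
  Node *dummy = g.CreateVarNode("dummy");  // as in PolishGraphToSupportDataHazards
  std::cout << dummy->var->name << "\n";   // prints "dummy"
}
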
paddle/fluid/framework/ir/node.h

@@ -21,6 +21,8 @@ limitations under the License. */
 #include <string>
 #include <unordered_set>
 #include <vector>
+#include "paddle/fluid/framework/op_desc.h"
+#include "paddle/fluid/framework/var_desc.h"
 #include "paddle/fluid/platform/macros.h"
 #include "paddle/fluid/platform/variant.h"
@@ -32,10 +34,12 @@ class Node {
  public:
  enum class Type { kNone = -1, kOperation, kVariable };
 
  Node() : type_(Type::kNone) {}
  explicit Node(Type type) : type_(type) {}
 
  virtual ~Node() {
    for (auto &attr : attrs_) {
      if (attr_dels_.find(attr.first) != attr_dels_.end()) {
        attr_dels_[attr.first]();
      }
@@ -47,23 +51,34 @@ class Node {
  Type NodeType() const { return type_; }
 
  template <typename AttrType>
  void Set(const std::string &name, AttrType attr) {
    attrs_[name] = attr;
  }
 
  template <typename AttrType>
  void Set(const std::string &name, AttrType *attr,
           std::function<void(void)> attr_del) {
    attrs_[name] = attr;
    attr_dels_[name] = attr_del;
  }
 
-  std::vector<Node *> inputs;
-  std::vector<Node *> outputs;
+  VarDesc *Var() { return var_desc_; }
+  OpDesc *Op() { return op_desc_; }
+
+  explicit Node(VarDesc *var_desc)
+      : var_desc_(var_desc), op_desc_(nullptr), type_(Type::kVariable) {}
+
+  explicit Node(OpDesc *op_desc)
+      : var_desc_(nullptr), op_desc_(op_desc), type_(Type::kOperation) {}
+
+  std::vector<Node *> inputs;
+  std::vector<Node *> outputs;
 
  protected:
  std::map<std::string, boost::any> attrs_;
  std::map<std::string, std::function<void(void)>> attr_dels_;
+  VarDesc *var_desc_;
+  OpDesc *op_desc_;
  Type type_;
 
 private:

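With the two new constructors, an ir::Node is permanently tagged at creation: it wraps either a VarDesc or an OpDesc, the other pointer stays null, and NodeType() records which. Callers are expected to check the tag before calling Var() or Op(). A small usage sketch with hypothetical descriptor structs standing in for the real ones:

#include <cassert>
#include <iostream>

struct VarDesc { const char *name; };
struct OpDesc { const char *type; };

class Node {
 public:
  enum class Type { kNone = -1, kOperation, kVariable };

  explicit Node(VarDesc *v) : var_(v), op_(nullptr), type_(Type::kVariable) {}
  explicit Node(OpDesc *o) : var_(nullptr), op_(o), type_(Type::kOperation) {}

  Type NodeType() const { return type_; }
  VarDesc *Var() { return var_; }
  OpDesc *Op() { return op_; }

 private:
  VarDesc *var_;
  OpDesc *op_;
  Type type_;
};

int main() {
  VarDesc w{"w"};
  Node n(&w);
  // Dispatch on the tag before touching the underlying descriptor.
  if (n.NodeType() == Node::Type::kVariable) {
    assert(n.Op() == nullptr);
    std::cout << n.Var()->name << "\n";  // prints "w"
  }
}
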
paddle/fluid/framework/ir/pass.h

@@ -20,15 +20,15 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 namespace ir {
 
 class Pass {
  public:
   Pass() = default;
   virtual ~Pass() {}
 
-  virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) {
-    return std::move(graph);
-  }
+  virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const = 0;
 };
 
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle

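The Pass interface change is small but meaningful: Apply() goes from a virtual with a default pass-through body to pure virtual and const, so every pass, even a do-nothing one, must now implement it explicitly, and applying a pass may not mutate the pass object. A minimal sketch of implementing against the new signature (mock Graph, not Paddle's):

#include <iostream>
#include <memory>
#include <vector>

struct Graph { std::vector<int> nodes; };

// After this commit, Apply is pure virtual and const: every pass must
// provide an implementation, and application is stateless.
class Pass {
 public:
  virtual ~Pass() {}
  virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const = 0;
};

// Even a do-nothing pass now has to spell out the pass-through explicitly;
// the old default body (return std::move(graph);) is gone.
class IdentityPass : public Pass {
 public:
  std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override {
    return graph;
  }
};

int main() {
  IdentityPass pass;
  auto g = pass.Apply(std::make_unique<Graph>());
  std::cout << g->nodes.size() << "\n";  // prints 0
}
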
paddle/fluid/framework/parallel_executor.cc

@@ -131,13 +131,10 @@ ParallelExecutor::ParallelExecutor(
       PADDLE_THROW("Not compiled with CUDA.");
 #endif
   }
 
   builder_ = builder_factory.Create();
-  std::unique_ptr<Graph> graph = builder_->Build(ProgramToGraph(main_program));
+  std::unique_ptr<Graph> graph = builder_->Apply(ProgramToGraph(main_program));
   member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
       exec_strategy, member_->local_scopes_, places, std::move(graph)));
 
   member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
       exec_strategy, member_->local_scopes_, std::move(var_infos),
       member_->places_, std::move(member_->executor_)));

python/paddle/fluid/parallel_executor.py

@@ -148,6 +148,7 @@ class ParallelExecutor(object):
             lambda var: var.persistable and var.type != core.VarDesc.VarType.RAW,
             main.list_vars())
         ]
 
+        sys.stderr.write('!!!!!!!!before\n')
         self.executor = core.ParallelExecutor(
             self._places,
@@ -158,6 +159,7 @@ class ParallelExecutor(object):
             set(self.persistable_vars), main.desc, loss_name
             if loss_name else '', scope, local_scopes, exec_strategy,
             build_strategy, num_trainers, trainer_id)
+        sys.stderr.write('!!!!!!!!after\n')
         self.scope = scope
 
     def run(self, fetch_list, feed=None, feed_dict=None, return_numpy=True):
