Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
687a3222
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
687a3222
编写于
7月 27, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
差异文件
Merge remote-tracking branch 'ups/develop' into refine/im2col
上级
65d418f0
ec4c6e1f
变更
39
显示空白变更内容
内联
并排
Showing
39 changed file
with
734 addition
and
253 deletion
+734
-253
doc/fluid/design/ir/draft.md
doc/fluid/design/ir/draft.md
+24
-24
paddle/fluid/framework/block_desc.h
paddle/fluid/framework/block_desc.h
+2
-3
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+2
-2
paddle/fluid/framework/details/multi_devices_graph_builder.cc
...le/fluid/framework/details/multi_devices_graph_builder.cc
+101
-50
paddle/fluid/framework/details/multi_devices_graph_builder.h
paddle/fluid/framework/details/multi_devices_graph_builder.h
+17
-14
paddle/fluid/framework/details/rpc_op_handle.cc
paddle/fluid/framework/details/rpc_op_handle.cc
+2
-1
paddle/fluid/framework/details/ssa_graph_builder.cc
paddle/fluid/framework/details/ssa_graph_builder.cc
+16
-8
paddle/fluid/framework/details/ssa_graph_builder.h
paddle/fluid/framework/details/ssa_graph_builder.h
+9
-12
paddle/fluid/framework/details/ssa_graph_checker.cc
paddle/fluid/framework/details/ssa_graph_checker.cc
+1
-1
paddle/fluid/framework/details/ssa_graph_checker.h
paddle/fluid/framework/details/ssa_graph_checker.h
+3
-2
paddle/fluid/framework/details/ssa_graph_printer.cc
paddle/fluid/framework/details/ssa_graph_printer.cc
+2
-2
paddle/fluid/framework/details/ssa_graph_printer.h
paddle/fluid/framework/details/ssa_graph_printer.h
+4
-3
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
...le/fluid/framework/details/threaded_ssa_graph_executor.cc
+2
-1
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
+2
-2
paddle/fluid/framework/details/var_handle.cc
paddle/fluid/framework/details/var_handle.cc
+1
-1
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+3
-2
paddle/fluid/framework/ir/graph.cc
paddle/fluid/framework/ir/graph.cc
+68
-17
paddle/fluid/framework/ir/graph.h
paddle/fluid/framework/ir/graph.h
+58
-16
paddle/fluid/framework/ir/graph_helper.cc
paddle/fluid/framework/ir/graph_helper.cc
+118
-0
paddle/fluid/framework/ir/graph_helper.h
paddle/fluid/framework/ir/graph_helper.h
+40
-0
paddle/fluid/framework/ir/graph_helper_test.cc
paddle/fluid/framework/ir/graph_helper_test.cc
+125
-0
paddle/fluid/framework/ir/graph_test.cc
paddle/fluid/framework/ir/graph_test.cc
+17
-15
paddle/fluid/framework/ir/node.cc
paddle/fluid/framework/ir/node.cc
+5
-1
paddle/fluid/framework/ir/node.h
paddle/fluid/framework/ir/node.h
+3
-0
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+1
-1
paddle/fluid/inference/api/demo_ci/clean.sh
paddle/fluid/inference/api/demo_ci/clean.sh
+4
-0
paddle/fluid/memory/detail/buddy_allocator.cc
paddle/fluid/memory/detail/buddy_allocator.cc
+11
-6
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+2
-2
paddle/fluid/operators/conv_cudnn_op.cu.cc
paddle/fluid/operators/conv_cudnn_op.cu.cc
+13
-13
paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc
paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc
+9
-9
paddle/fluid/operators/math/softmax.cu
paddle/fluid/operators/math/softmax.cu
+2
-2
paddle/fluid/operators/pool_cudnn_op.cu.cc
paddle/fluid/operators/pool_cudnn_op.cu.cc
+2
-2
paddle/fluid/operators/send_recv_util.h
paddle/fluid/operators/send_recv_util.h
+5
-1
paddle/fluid/platform/cudnn_helper.h
paddle/fluid/platform/cudnn_helper.h
+6
-7
paddle/scripts/paddle_build.sh
paddle/scripts/paddle_build.sh
+1
-0
python/paddle/fluid/__init__.py
python/paddle/fluid/__init__.py
+28
-28
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+22
-1
python/paddle/fluid/layers/io.py
python/paddle/fluid/layers/io.py
+0
-3
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+3
-1
未找到文件。
doc/fluid/design/ir/draft.md
浏览文件 @
687a3222
## Motivation
## Motivation
There is a
`
``gap```
between the
```Program``
`
defined by
There is a
`
gap`
between the
`Program
`
defined by
user and the
`
``Executable``
`
that can be scheduled
user and the
`
Executable
`
that can be scheduled
efficiently on heterogeneous hardware, either locally
efficiently on heterogeneous hardware, either locally
or distributedly.
or distributedly.
Usually, the
`
``gap``
`
is bridged by
Usually, the
`
gap
`
is bridged by
*
A serious transformations with defined order.
*
A serious transformations with defined order.
*
These transformations usually involve
*
These transformations usually involve
`
``
insert, delete, clustering, split, dependency analysis``
`.
`
insert, delete, clustering, split, dependency analysis
`
.
*
Has a simple way to verify and debug each transformation.
*
Has a simple way to verify and debug each transformation.
...
@@ -38,44 +38,44 @@ design below.
...
@@ -38,44 +38,44 @@ design below.
#### Node
#### Node
`
``
Node
``
` represents an operation that performs some computation or
`
Node
`
represents an operation that performs some computation or
a variable that is input or output of operation.
a variable that is input or output of operation.
`
``
Node
```s are connected to other ```
Node
``
`s via inputs and outputs.
`
Node`
s are connected to other
`Node
`
s via inputs and outputs.
Other properties (maybe device placement information) can be added
Other properties (maybe device placement information) can be added
to `
``
Node
``
` in the future if it's a
to
`
Node
`
in the future if it's a
common requirement of many other `
``
Pass
``
`es. Otherwise, it should live
common requirement of many other
`
Pass
`
es. Otherwise, it should live
in a `
``
Node
``` wrapper class that is private to some ```
Pass
``
` or be
in a
`
Node`
wrapper class that is private to some
`Pass
`
or be
a local member of a `
``
Pass
``
`.
a local member of a
`
Pass
`
.
#### Graph
#### Graph
`
``
Graph
``` contains a list of ```
Node
``
`s, which are connected to
`
Graph`
contains a list of
`Node
`
s, which are connected to
each other via inputs and outputs.
each other via inputs and outputs.
TODO: Better definitions for the graph.
TODO: Better definitions for the graph.
`
``
Graph
``` can also contain ```
Attribute
```s. ```
Attribute
``
`s
`
Graph`
can also contain
`Attribute`
s.
`Attribute
`
s
can be `
`any`
` thing. For example, it can be a list of "wraper"
can be
`
any
`
thing. For example, it can be a list of "wraper"
nodes. The `
``
wrapper
``` nodes compose ```
Node
``
`s and provide
nodes. The
`
wrapper`
nodes compose
`Node
`
s and provide
helper method for execution or transformation. `
``
Attribute
``
`
helper method for execution or transformation.
`
Attribute
`
can also contain other things that describe some properties of
can also contain other things that describe some properties of
the `
``
Graph
``` or ```
Graph
``` nodes. ```
Attribute
``
` can be passed
the
`
Graph`
or
`Graph`
nodes.
`Attribute
`
can be passed
across `
``
Pass
``
`. However, it should be used with care.
across
`
Pass
`
. However, it should be used with care.
#### Pass
#### Pass
`
``
Pass
``` represents a transformation of ```
Graph
``
`. Its input
`
Pass`
represents a transformation of
`Graph
`
. Its input
is a `
``
Graph
``` and its output is also a ```
Graph
``
`. For example,
is a
`
Graph`
and its output is also a
`Graph
`
. For example,
a `
``
Pass
``` can simply print out the ```
Graph
```. A ```
Pass
``
`
a
`
Pass`
can simply print out the
`Graph`
. A
`Pass
`
can also fuse some `
``
Graph
```'s ```
Node
``
`s.
can also fuse some
`
Graph`
's
`Node
`
s.
#### Optimize
#### Optimize
`
``
Optimize
``` contains a series of ```
Pass
``
` with defined order.
`
Optimize`
contains a series of
`Pass
`
with defined order.
`
``
Optimize
``` transforms a ```
Graph
``
` that only contains raw
`
Optimize`
transforms a
`Graph
`
that only contains raw
modeling logic to a `
``
Graph
``
`
that can be run efficiently while
modeling logic to a
`
Graph
`
that can be run efficiently while
maintaining the original modeling logic.
maintaining the original modeling logic.
...
...
paddle/fluid/framework/block_desc.h
浏览文件 @
687a3222
...
@@ -88,9 +88,8 @@ class BlockDesc {
...
@@ -88,9 +88,8 @@ class BlockDesc {
OpDesc
*
InsertOp
(
size_t
index
);
OpDesc
*
InsertOp
(
size_t
index
);
/*
/*
* Remove Op and its input/output variables.
* Only remove op itself,
* Note that for either input or output variable, if it is also an input or
* do nothing to its input and output variables
* output variable of other ops, we should remain it.
*/
*/
void
RemoveOp
(
size_t
s
,
size_t
e
);
void
RemoveOp
(
size_t
s
,
size_t
e
);
...
...
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
687a3222
cc_library
(
var_handle SRCS var_handle.cc DEPS place framework_proto
)
cc_library
(
var_handle SRCS var_handle.cc DEPS place framework_proto
node
)
cc_library
(
op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context lod_tensor
)
cc_library
(
op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context lod_tensor
)
cc_library
(
scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
)
cc_library
(
scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
)
cc_library
(
fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
)
cc_library
(
fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
)
cc_library
(
computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry
)
cc_library
(
computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry
)
cc_library
(
rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry
)
cc_library
(
rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry
)
cc_library
(
ssa_graph_builder SRCS ssa_graph_builder.cc DEPS graph
)
cc_library
(
ssa_graph_builder SRCS ssa_graph_builder.cc DEPS graph
graph_helper
)
cc_library
(
ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder
)
cc_library
(
ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder
)
cc_library
(
ssa_graph_checker SRCS ssa_graph_checker.cc DEPS ssa_graph_builder
)
cc_library
(
ssa_graph_checker SRCS ssa_graph_checker.cc DEPS ssa_graph_builder
)
...
...
paddle/fluid/framework/details/multi_devices_graph_builder.cc
浏览文件 @
687a3222
...
@@ -25,6 +25,7 @@
...
@@ -25,6 +25,7 @@
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
...
@@ -67,7 +68,8 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
...
@@ -67,7 +68,8 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
}
}
}
}
void
MultiDevSSAGraphBuilder
::
CreateOpHandleIOs
(
Graph
*
result
,
ir
::
Node
*
node
,
void
MultiDevSSAGraphBuilder
::
CreateOpHandleIOs
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
size_t
place_id
)
const
{
size_t
place_id
)
const
{
auto
p
=
places_
[
place_id
];
auto
p
=
places_
[
place_id
];
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
();
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
();
...
@@ -92,12 +94,11 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(Graph *result, ir::Node *node,
...
@@ -92,12 +94,11 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(Graph *result, ir::Node *node,
}
}
std
::
vector
<
std
::
string
>
MultiDevSSAGraphBuilder
::
FindDistTrainSendVars
(
std
::
vector
<
std
::
string
>
MultiDevSSAGraphBuilder
::
FindDistTrainSendVars
(
const
std
::
vector
<
std
::
unique_ptr
<
ir
::
Node
>
>
&
nodes
)
const
{
const
std
::
vector
<
ir
::
Node
*
>
&
nodes
)
const
{
std
::
vector
<
std
::
string
>
send_vars
;
std
::
vector
<
std
::
string
>
send_vars
;
// since parameters are all in block 0,
// since parameters are all in block 0,
// it's enough to only scan send ops in block 0
// it's enough to only scan send ops in block 0
for
(
auto
&
node
:
nodes
)
{
for
(
auto
&
node
:
nodes
)
{
if
(
node
->
NodeType
()
!=
ir
::
Node
::
Type
::
kOperation
)
continue
;
OpDesc
*
op
=
node
->
Op
();
OpDesc
*
op
=
node
->
Op
();
// TODO(Yancey1989): use a graceful method to find send op,
// TODO(Yancey1989): use a graceful method to find send op,
// instead of the the hard code string
// instead of the the hard code string
...
@@ -112,10 +113,9 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
...
@@ -112,10 +113,9 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
}
}
std
::
vector
<
std
::
string
>
MultiDevSSAGraphBuilder
::
FindDistTrainRecvVars
(
std
::
vector
<
std
::
string
>
MultiDevSSAGraphBuilder
::
FindDistTrainRecvVars
(
const
std
::
vector
<
std
::
unique_ptr
<
ir
::
Node
>
>
&
nodes
)
const
{
const
std
::
vector
<
ir
::
Node
*
>
&
nodes
)
const
{
std
::
vector
<
std
::
string
>
recv_vars
;
std
::
vector
<
std
::
string
>
recv_vars
;
for
(
auto
&
node
:
nodes
)
{
for
(
auto
&
node
:
nodes
)
{
if
(
node
->
NodeType
()
!=
ir
::
Node
::
Type
::
kOperation
)
continue
;
OpDesc
*
op
=
node
->
Op
();
OpDesc
*
op
=
node
->
Op
();
// TODO(Yancey1989): use a graceful method to find recv op,
// TODO(Yancey1989): use a graceful method to find recv op,
// instead of the hard code string
// instead of the hard code string
...
@@ -170,6 +170,7 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
...
@@ -170,6 +170,7 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
const
std
::
vector
<
std
::
string
>
&
var_names
)
const
{
const
std
::
vector
<
std
::
string
>
&
var_names
)
const
{
int64_t
numel_sum
=
0
;
int64_t
numel_sum
=
0
;
for
(
auto
var_name
:
var_names
)
{
for
(
auto
var_name
:
var_names
)
{
if
(
all_vars_
.
find
(
var_name
)
==
all_vars_
.
end
())
continue
;
auto
var_desc
=
all_vars_
.
at
(
var_name
);
auto
var_desc
=
all_vars_
.
at
(
var_name
);
PADDLE_ENFORCE_NOT_NULL
(
var_desc
);
PADDLE_ENFORCE_NOT_NULL
(
var_desc
);
auto
dim
=
framework
::
make_ddim
(
var_desc
->
GetShape
());
auto
dim
=
framework
::
make_ddim
(
var_desc
->
GetShape
());
...
@@ -186,19 +187,70 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
...
@@ -186,19 +187,70 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
return
dev_id
;
return
dev_id
;
}
}
std
::
unique_ptr
<
Graph
>
MultiDevSSAGraphBuilder
::
Apply
(
// Topology sort the graph nodes from inputs to outputs.
std
::
unique_ptr
<
Graph
>
graph
)
const
{
// Since SSAGraphBuilder depends on forward/backward nodes to assign devices
// Rebuild the graph structure.
// to parameter/gradients before optimizer ops, topo sort is insufficient. (
auto
nodes
=
std
::
move
(
graph
->
nodes
);
// some optimizer ops might not depend on any nodes), we manually move all
graph
->
nodes
.
clear
();
// optimizer nodes after last backward nodes.
// However, the assumption by SSAGraphBuilder should be relaxed in the future.
std
::
vector
<
ir
::
Node
*>
SortOpsAndDelayOptimizeOp
(
const
ir
::
Graph
&
graph
)
{
std
::
vector
<
ir
::
Node
*>
ret
=
ir
::
TopologySortOperations
(
graph
);
size_t
last_backward
=
0
;
for
(
size_t
i
=
0
;
i
<
ret
.
size
();
++
i
)
{
if
(
boost
::
get
<
int
>
(
ret
[
i
]
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
==
static_cast
<
int
>
(
OpRole
::
kBackward
))
{
last_backward
=
i
;
}
}
std
::
vector
<
ir
::
Node
*>
optimize_ops
;
std
::
vector
<
ir
::
Node
*>
sorted_ret
;
for
(
size_t
i
=
0
;
i
<
ret
.
size
();
++
i
)
{
if
(
i
<
last_backward
)
{
if
(
boost
::
get
<
int
>
(
ret
[
i
]
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
==
static_cast
<
int
>
(
OpRole
::
kOptimize
))
{
optimize_ops
.
push_back
(
ret
[
i
]);
}
else
{
sorted_ret
.
push_back
(
ret
[
i
]);
}
}
else
if
(
i
==
last_backward
)
{
sorted_ret
.
push_back
(
ret
[
i
]);
// Verify that no operations before optimize ops depends on optimize ops.
std
::
unordered_set
<
ir
::
Node
*>
optimize_set
(
optimize_ops
.
begin
(),
optimize_ops
.
end
());
for
(
ir
::
Node
*
n
:
sorted_ret
)
{
for
(
ir
::
Node
*
in
:
n
->
inputs
)
{
for
(
ir
::
Node
*
pre_n
:
in
->
inputs
)
{
PADDLE_ENFORCE
(
optimize_set
.
find
(
pre_n
)
==
optimize_set
.
end
(),
"optimize operations cannot be depended by forward "
"or backward node %s -> %s"
,
pre_n
->
Name
(),
n
->
Name
());
}
}
}
sorted_ret
.
insert
(
sorted_ret
.
end
(),
optimize_ops
.
begin
(),
optimize_ops
.
end
());
}
else
{
sorted_ret
.
push_back
(
ret
[
i
]);
}
}
return
sorted_ret
;
}
std
::
unique_ptr
<
ir
::
Graph
>
MultiDevSSAGraphBuilder
::
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
// Give the topology sort order and rebuild the graph structure.
std
::
vector
<
ir
::
Node
*>
sorted_ops
=
SortOpsAndDelayOptimizeOp
(
*
graph
);
auto
nodes
=
graph
->
ReleaseNodes
();
ir
::
Graph
&
result
=
*
graph
;
for
(
auto
&
node
:
nodes
)
{
for
(
auto
&
node
:
nodes
)
{
if
(
node
->
NodeType
()
==
ir
::
Node
::
Type
::
kVariable
)
{
if
(
node
->
NodeType
()
==
ir
::
Node
::
Type
::
kVariable
&&
node
->
Var
()
)
{
all_vars_
.
emplace
(
node
->
Name
(),
node
->
Var
());
all_vars_
.
emplace
(
node
->
Name
(),
node
->
Var
());
}
}
}
}
Graph
&
result
=
*
graph
;
std
::
unordered_set
<
std
::
string
>
og_has_been_broadcast
;
std
::
unordered_set
<
std
::
string
>
og_has_been_broadcast
;
// We cannot invoke resize. It is a bug of GCC 4.8
// We cannot invoke resize. It is a bug of GCC 4.8
...
@@ -207,9 +259,9 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
...
@@ -207,9 +259,9 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
result
.
Set
(
"ops"
,
new
GraphOps
);
result
.
Set
(
"ops"
,
new
GraphOps
);
// find send/recv vars so that we can place the distributed training
// find send/recv vars so that we can place the distributed training
// re
al
ted op in the place 0
// re
la
ted op in the place 0
auto
send_vars
=
FindDistTrainSendVars
(
node
s
);
auto
send_vars
=
FindDistTrainSendVars
(
sorted_op
s
);
auto
recv_vars
=
FindDistTrainRecvVars
(
node
s
);
auto
recv_vars
=
FindDistTrainRecvVars
(
sorted_op
s
);
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
bcast_var_name_set
;
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
bcast_var_name_set
;
bcast_var_name_set
.
resize
(
places_
.
size
());
bcast_var_name_set
.
resize
(
places_
.
size
());
...
@@ -217,22 +269,18 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
...
@@ -217,22 +269,18 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
size_t
cur_device_id
=
0
;
size_t
cur_device_id
=
0
;
bool
is_forwarding
=
true
;
bool
is_forwarding
=
true
;
// NOTE: Currently, passes before SSAGraphBuilder cannot reorder
for
(
ir
::
Node
*
node
:
sorted_ops
)
{
// forward, backward nodes. E.g. you can't append an forward node
// at the end of the node list.
// TODO(panyx0718): FIXME: Needs to sort by forward->backward order.
for
(
auto
&
node
:
nodes
)
{
if
(
node
->
NodeType
()
!=
ir
::
Node
::
Type
::
kOperation
)
continue
;
if
(
boost
::
get
<
int
>
(
if
(
boost
::
get
<
int
>
(
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
==
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
==
static_cast
<
int
>
(
OpRole
::
kRPC
))
{
static_cast
<
int
>
(
OpRole
::
kRPC
))
{
CreateRPCOp
(
&
result
,
node
.
get
()
);
CreateRPCOp
(
&
result
,
node
);
}
else
if
(
IsDistTrainOp
(
node
.
get
()
,
send_vars
,
recv_vars
))
{
}
else
if
(
IsDistTrainOp
(
node
,
send_vars
,
recv_vars
))
{
CreateDistTrainOp
(
&
result
,
node
.
get
()
);
CreateDistTrainOp
(
&
result
,
node
);
}
else
if
(
IsScaleLossOp
(
node
.
get
()
))
{
}
else
if
(
IsScaleLossOp
(
node
))
{
// user can customize loss@grad if not use_default_grad_scale_
// user can customize loss@grad if not use_default_grad_scale_
if
(
strategy_
.
gradient_scale_
!=
if
(
strategy_
.
gradient_scale_
!=
BuildStrategy
::
GradientScaleStrategy
::
kCustomized
)
{
BuildStrategy
::
GradientScaleStrategy
::
kCustomized
)
{
// TODO(paddle-dev): Why is there no input for this op_handle?
CreateScaleLossGradOp
(
&
result
);
CreateScaleLossGradOp
(
&
result
);
}
}
// This assumes the backward generating code will ensure IsScaleLossOp
// This assumes the backward generating code will ensure IsScaleLossOp
...
@@ -241,24 +289,23 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
...
@@ -241,24 +289,23 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
// the block.
// the block.
is_forwarding
=
false
;
is_forwarding
=
false
;
}
else
{
}
else
{
int
op_dev_id
=
GetOpDeviceID
(
node
.
get
()
);
int
op_dev_id
=
GetOpDeviceID
(
node
);
if
(
op_dev_id
!=
-
1
)
{
// This op only runs on one specific device.
if
(
op_dev_id
!=
-
1
)
{
// This op only runs on one specific device.
CreateComputationalOp
(
&
result
,
node
.
get
()
,
op_dev_id
);
CreateComputationalOp
(
&
result
,
node
,
op_dev_id
);
for
(
ir
::
Node
*
n
:
node
->
outputs
)
{
for
(
ir
::
Node
*
n
:
node
->
outputs
)
{
var_name_on_devices_
.
emplace
(
n
->
Name
(),
op_dev_id
);
var_name_on_devices_
.
emplace
(
n
->
Name
(),
op_dev_id
);
}
}
}
else
{
}
else
{
// This op runs on all devices, and its output may have parameter's
// This op runs on all devices, and its output may have parameter's
// gradients.
// gradients.
// TODO(paddle-dev): Why is so special about "read" op?
if
(
node
->
Op
()
->
Type
()
==
"read"
&&
strategy_
.
enable_data_balance_
)
{
if
(
node
->
Op
()
->
Type
()
==
"read"
&&
strategy_
.
enable_data_balance_
)
{
node
->
Op
()
->
SetAttr
(
"throw_eof_exp"
,
false
);
node
->
Op
()
->
SetAttr
(
"throw_eof_exp"
,
false
);
CreateComputationalOps
(
&
result
,
node
.
get
(),
places_
.
size
());
CreateComputationalOps
(
&
result
,
node
,
places_
.
size
());
// TODO(paddle-dev): builder shouldn't depend on the out logic of
// a specific op.
const
auto
&
data_var_names
=
node
->
Op
()
->
Output
(
"Out"
);
const
auto
&
data_var_names
=
node
->
Op
()
->
Output
(
"Out"
);
InsertDataBalanceOp
(
&
result
,
data_var_names
);
InsertDataBalanceOp
(
&
result
,
data_var_names
);
}
else
{
}
else
{
CreateComputationalOps
(
&
result
,
node
.
get
()
,
places_
.
size
());
CreateComputationalOps
(
&
result
,
node
,
places_
.
size
());
}
}
if
(
!
is_forwarding
&&
places_
.
size
()
>
1
)
{
if
(
!
is_forwarding
&&
places_
.
size
()
>
1
)
{
...
@@ -322,7 +369,6 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
...
@@ -322,7 +369,6 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
}
}
}
}
}
}
/*
/*
Dependency graph has been constructed. However, there are still data
Dependency graph has been constructed. However, there are still data
hazards need to be handled.
hazards need to be handled.
...
@@ -333,6 +379,7 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
...
@@ -333,6 +379,7 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
* Only variables should be the leaves of graph.
* Only variables should be the leaves of graph.
*/
*/
AddOutputToLeafOps
(
&
result
);
AddOutputToLeafOps
(
&
result
);
PADDLE_ENFORCE
(
!
ir
::
HasCircle
(
result
));
return
graph
;
return
graph
;
}
}
...
@@ -357,7 +404,7 @@ void MultiDevSSAGraphBuilder::SetCommunicationContext(
...
@@ -357,7 +404,7 @@ void MultiDevSSAGraphBuilder::SetCommunicationContext(
#endif
#endif
}
}
void
MultiDevSSAGraphBuilder
::
CreateBroadcastOp
(
Graph
*
result
,
void
MultiDevSSAGraphBuilder
::
CreateBroadcastOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
p_name
,
const
std
::
string
&
p_name
,
size_t
src_dev_id
)
const
{
size_t
src_dev_id
)
const
{
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
...
@@ -387,7 +434,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
...
@@ -387,7 +434,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
}
}
}
}
void
MultiDevSSAGraphBuilder
::
CreateComputationalOp
(
Graph
*
result
,
void
MultiDevSSAGraphBuilder
::
CreateComputationalOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
ir
::
Node
*
node
,
int
dev_id
)
const
{
int
dev_id
)
const
{
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
...
@@ -396,7 +443,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
...
@@ -396,7 +443,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
CreateOpHandleIOs
(
result
,
node
,
dev_id
);
CreateOpHandleIOs
(
result
,
node
,
dev_id
);
}
}
void
MultiDevSSAGraphBuilder
::
InsertAllReduceOp
(
Graph
*
result
,
void
MultiDevSSAGraphBuilder
::
InsertAllReduceOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
og
)
const
{
const
std
::
string
&
og
)
const
{
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
AllReduceOpHandle
(
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
AllReduceOpHandle
(
...
@@ -426,7 +473,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
...
@@ -426,7 +473,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
}
}
void
MultiDevSSAGraphBuilder
::
InsertDataBalanceOp
(
void
MultiDevSSAGraphBuilder
::
InsertDataBalanceOp
(
Graph
*
result
,
const
std
::
vector
<
std
::
string
>
&
datas
)
const
{
ir
::
Graph
*
result
,
const
std
::
vector
<
std
::
string
>
&
datas
)
const
{
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
DataBalanceOpHandle
(
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
DataBalanceOpHandle
(
result
->
CreateEmptyNode
(
"data_balance"
,
ir
::
Node
::
Type
::
kOperation
),
result
->
CreateEmptyNode
(
"data_balance"
,
ir
::
Node
::
Type
::
kOperation
),
...
@@ -479,8 +526,8 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const {
...
@@ -479,8 +526,8 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const {
PADDLE_ENFORCE_EQ
(
param_grad
.
size
(),
2U
);
PADDLE_ENFORCE_EQ
(
param_grad
.
size
(),
2U
);
int
dev_id
=
GetVarDeviceID
(
param_grad
[
1
]);
int
dev_id
=
GetVarDeviceID
(
param_grad
[
1
]);
PADDLE_ENFORCE_NE
(
dev_id
,
-
1
,
"dev_id should not be -1.[%s, %s]"
,
PADDLE_ENFORCE_NE
(
dev_id
,
-
1
,
"dev_id should not be -1.[%s, %s
, %s
]"
,
node
->
Op
()
->
Type
(),
param_grad
[
0
]);
node
->
Op
()
->
Type
(),
param_grad
[
0
]
,
param_grad
[
1
]
);
return
dev_id
;
return
dev_id
;
}
}
...
@@ -489,7 +536,7 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const {
...
@@ -489,7 +536,7 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const {
return
got
==
var_name_on_devices_
.
end
()
?
-
1
:
got
->
second
;
return
got
==
var_name_on_devices_
.
end
()
?
-
1
:
got
->
second
;
}
}
void
MultiDevSSAGraphBuilder
::
CreateScaleLossGradOp
(
Graph
*
result
)
const
{
void
MultiDevSSAGraphBuilder
::
CreateScaleLossGradOp
(
ir
::
Graph
*
result
)
const
{
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
// Insert ScaleCost OpHandle
// Insert ScaleCost OpHandle
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
...
@@ -519,7 +566,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
...
@@ -519,7 +566,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
}
}
}
}
void
MultiDevSSAGraphBuilder
::
CreateComputationalOps
(
Graph
*
result
,
void
MultiDevSSAGraphBuilder
::
CreateComputationalOps
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
ir
::
Node
*
node
,
size_t
num_places
)
const
{
size_t
num_places
)
const
{
for
(
size_t
scope_idx
=
0
;
scope_idx
<
num_places
;
++
scope_idx
)
{
for
(
size_t
scope_idx
=
0
;
scope_idx
<
num_places
;
++
scope_idx
)
{
...
@@ -531,7 +578,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result,
...
@@ -531,7 +578,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result,
}
}
}
}
VarHandle
*
MultiDevSSAGraphBuilder
::
CreateReduceOp
(
Graph
*
result
,
VarHandle
*
MultiDevSSAGraphBuilder
::
CreateReduceOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
og
,
const
std
::
string
&
og
,
int
dst_dev_id
)
const
{
int
dst_dev_id
)
const
{
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
...
@@ -564,12 +611,11 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
...
@@ -564,12 +611,11 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
// Find the first occurence of `prev_op_name` and make current `op` depend
// Find the first occurence of `prev_op_name` and make current `op` depend
// on it.
// on it.
void
MultiDevSSAGraphBuilder
::
ConnectOp
(
Graph
*
result
,
OpHandleBase
*
op
,
void
MultiDevSSAGraphBuilder
::
ConnectOp
(
ir
::
Graph
*
result
,
OpHandleBase
*
op
,
const
std
::
string
&
prev_op_name
)
const
{
const
std
::
string
&
prev_op_name
)
const
{
for
(
auto
&
prev_op
:
result
->
Get
<
GraphOps
>
(
"ops"
))
{
for
(
auto
&
prev_op
:
result
->
Get
<
GraphOps
>
(
"ops"
))
{
if
(
prev_op
->
Name
()
==
prev_op_name
)
{
if
(
prev_op
->
Name
()
==
prev_op_name
)
{
auto
*
dep_var
=
new
DummyVarHandle
(
auto
*
dep_var
=
new
DummyVarHandle
(
result
->
CreateControlDepVar
());
result
->
CreateEmptyNode
(
"dummy"
,
ir
::
Node
::
Type
::
kVariable
));
prev_op
->
AddOutput
(
dep_var
);
prev_op
->
AddOutput
(
dep_var
);
result
->
Get
<
GraphDepVars
>
(
"dep_vars"
).
emplace
(
dep_var
);
result
->
Get
<
GraphDepVars
>
(
"dep_vars"
).
emplace
(
dep_var
);
op
->
AddInput
(
dep_var
);
op
->
AddInput
(
dep_var
);
...
@@ -577,7 +623,7 @@ void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op,
...
@@ -577,7 +623,7 @@ void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op,
}
}
}
}
void
MultiDevSSAGraphBuilder
::
CreateDistTrainOp
(
Graph
*
result
,
void
MultiDevSSAGraphBuilder
::
CreateDistTrainOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
{
ir
::
Node
*
node
)
const
{
int
op_dev_id
=
-
1
;
int
op_dev_id
=
-
1
;
std
::
vector
<
std
::
string
>
input_var_names
;
std
::
vector
<
std
::
string
>
input_var_names
;
...
@@ -591,6 +637,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result,
...
@@ -591,6 +637,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result,
if
(
node
->
Op
()
->
Type
()
==
"split_byref"
||
if
(
node
->
Op
()
->
Type
()
==
"split_byref"
||
node
->
Op
()
->
Type
()
==
"split_selected_rows"
)
{
node
->
Op
()
->
Type
()
==
"split_selected_rows"
)
{
// TODO(paddle-dev): getting the first var is not safe.
op_dev_id
=
GetVarDeviceID
(
input_var_names
[
0
]);
op_dev_id
=
GetVarDeviceID
(
input_var_names
[
0
]);
if
(
strategy_
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kAllReduce
)
{
if
(
strategy_
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kAllReduce
)
{
op_dev_id
=
GetAppropriateDeviceID
(
input_var_names
);
op_dev_id
=
GetAppropriateDeviceID
(
input_var_names
);
...
@@ -624,10 +671,14 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result,
...
@@ -624,10 +671,14 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result,
}
}
// Create RPC related op handles that connects its in ops and out ops.
// Create RPC related op handles that connects its in ops and out ops.
void
MultiDevSSAGraphBuilder
::
CreateRPCOp
(
Graph
*
result
,
ir
::
Node
*
node
)
const
{
void
MultiDevSSAGraphBuilder
::
CreateRPCOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
{
int
op_dev_id
=
-
1
;
int
op_dev_id
=
-
1
;
if
(
node
->
Op
()
->
Type
()
==
"send"
)
{
if
(
node
->
Op
()
->
Type
()
==
"send"
)
{
// TODO(paddle-dev): getting the first var is not safe.
op_dev_id
=
GetVarDeviceID
(
node
->
inputs
[
0
]
->
Name
());
op_dev_id
=
GetVarDeviceID
(
node
->
inputs
[
0
]
->
Name
());
PADDLE_ENFORCE
(
!
ir
::
IsControlDepVar
(
*
node
->
inputs
[
0
]),
"This hack no longer holds, please fix."
);
// the variable name which contains .block means it was splited by
// the variable name which contains .block means it was splited by
// split_byref op
// split_byref op
// so that we can balance the variable blocks to all the pserver
// so that we can balance the variable blocks to all the pserver
...
...
paddle/fluid/framework/details/multi_devices_graph_builder.h
浏览文件 @
687a3222
...
@@ -46,11 +46,13 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
...
@@ -46,11 +46,13 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
BuildStrategy
&
strategy
);
const
BuildStrategy
&
strategy
);
#endif
#endif
std
::
unique_ptr
<
Graph
>
Apply
(
std
::
unique_ptr
<
Graph
>
graph
)
const
override
;
std
::
unique_ptr
<
ir
::
Graph
>
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
int
GetVarDeviceID
(
const
std
::
string
&
varname
)
const
override
;
int
GetVarDeviceID
(
const
std
::
string
&
varname
)
const
override
;
private:
private:
void
CreateOpHandleIOs
(
Graph
*
result
,
ir
::
Node
*
node
,
size_t
device_id
)
const
;
void
CreateOpHandleIOs
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
size_t
device_id
)
const
;
private:
private:
std
::
string
loss_var_name_
;
std
::
string
loss_var_name_
;
...
@@ -64,8 +66,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
...
@@ -64,8 +66,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
bool
IsScaleLossOp
(
ir
::
Node
*
node
)
const
;
bool
IsScaleLossOp
(
ir
::
Node
*
node
)
const
;
void
CreateRPCOp
(
Graph
*
result
,
ir
::
Node
*
node
)
const
;
void
CreateRPCOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
;
void
CreateDistTrainOp
(
Graph
*
result
,
ir
::
Node
*
node
)
const
;
void
CreateDistTrainOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
;
/**
/**
* Is this operator as the end-point operator before/after send operator.
* Is this operator as the end-point operator before/after send operator.
...
@@ -74,21 +76,22 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
...
@@ -74,21 +76,22 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const
std
::
vector
<
std
::
string
>
&
recv_vars
)
const
;
const
std
::
vector
<
std
::
string
>
&
recv_vars
)
const
;
std
::
vector
<
std
::
string
>
FindDistTrainSendVars
(
std
::
vector
<
std
::
string
>
FindDistTrainSendVars
(
const
std
::
vector
<
std
::
unique_ptr
<
ir
::
Node
>
>
&
nodes
)
const
;
const
std
::
vector
<
ir
::
Node
*
>
&
nodes
)
const
;
std
::
vector
<
std
::
string
>
FindDistTrainRecvVars
(
std
::
vector
<
std
::
string
>
FindDistTrainRecvVars
(
const
std
::
vector
<
std
::
unique_ptr
<
ir
::
Node
>
>
&
nodes
)
const
;
const
std
::
vector
<
ir
::
Node
*
>
&
nodes
)
const
;
void
ConnectOp
(
Graph
*
result
,
OpHandleBase
*
op
,
void
ConnectOp
(
ir
::
Graph
*
result
,
OpHandleBase
*
op
,
const
std
::
string
&
prev_op_name
)
const
;
const
std
::
string
&
prev_op_name
)
const
;
void
CreateComputationalOps
(
Graph
*
result
,
ir
::
Node
*
node
,
void
CreateComputationalOps
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
size_t
num_places
)
const
;
size_t
num_places
)
const
;
void
CreateScaleLossGradOp
(
Graph
*
result
)
const
;
void
CreateScaleLossGradOp
(
ir
::
Graph
*
result
)
const
;
VarHandle
*
CreateReduceOp
(
Graph
*
result
,
const
std
::
string
&
og
,
VarHandle
*
CreateReduceOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
og
,
int
dst_dev_id
)
const
;
int
dst_dev_id
)
const
;
void
CreateComputationalOp
(
Graph
*
result
,
ir
::
Node
*
node
,
int
dev_id
)
const
;
void
CreateComputationalOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
int
dev_id
)
const
;
bool
IsParameterGradientOnce
(
bool
IsParameterGradientOnce
(
const
std
::
string
&
og
,
const
std
::
string
&
og
,
...
@@ -96,12 +99,12 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
...
@@ -96,12 +99,12 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
int
GetOpDeviceID
(
ir
::
Node
*
node
)
const
;
int
GetOpDeviceID
(
ir
::
Node
*
node
)
const
;
void
InsertAllReduceOp
(
Graph
*
result
,
const
std
::
string
&
og
)
const
;
void
InsertAllReduceOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
og
)
const
;
void
InsertDataBalanceOp
(
Graph
*
result
,
void
InsertDataBalanceOp
(
ir
::
Graph
*
result
,
const
std
::
vector
<
std
::
string
>
&
datas
)
const
;
const
std
::
vector
<
std
::
string
>
&
datas
)
const
;
void
CreateBroadcastOp
(
Graph
*
result
,
const
std
::
string
&
p_name
,
void
CreateBroadcastOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
p_name
,
size_t
src_dev_id
)
const
;
size_t
src_dev_id
)
const
;
bool
IsSparseGradient
(
const
std
::
string
&
og
)
const
;
bool
IsSparseGradient
(
const
std
::
string
&
og
)
const
;
...
...
paddle/fluid/framework/details/rpc_op_handle.cc
浏览文件 @
687a3222
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/details/rpc_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -33,7 +34,7 @@ void RPCOpHandle::RunImpl() {
...
@@ -33,7 +34,7 @@ void RPCOpHandle::RunImpl() {
for
(
auto
*
in
:
inputs_
)
{
for
(
auto
*
in
:
inputs_
)
{
auto
&
p
=
static_cast
<
VarHandle
*>
(
in
)
->
place_
;
auto
&
p
=
static_cast
<
VarHandle
*>
(
in
)
->
place_
;
// FIXME(Yancey1989): need a better solution instead of use DebugString()
// FIXME(Yancey1989): need a better solution instead of use DebugString()
if
(
i
n
->
DebugString
()
==
"dummy"
)
{
// HACK
if
(
i
r
::
IsControlDepVar
(
*
in
->
Node
())
)
{
// HACK
continue
;
continue
;
}
}
if
(
in
->
GeneratedOp
())
{
if
(
in
->
GeneratedOp
())
{
...
...
paddle/fluid/framework/details/ssa_graph_builder.cc
浏览文件 @
687a3222
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
void
SSAGraphBuilder
::
PolishGraphToSupportDataHazards
(
Graph
*
graph
)
{
void
SSAGraphBuilder
::
PolishGraphToSupportDataHazards
(
ir
::
Graph
*
graph
)
{
for
(
auto
&
var_map
:
graph
->
Get
<
GraphVars
>
(
"vars"
))
{
for
(
auto
&
var_map
:
graph
->
Get
<
GraphVars
>
(
"vars"
))
{
for
(
auto
&
name_pair
:
var_map
)
{
for
(
auto
&
name_pair
:
var_map
)
{
if
(
name_pair
.
second
.
size
()
<=
1
)
{
if
(
name_pair
.
second
.
size
()
<=
1
)
{
...
@@ -36,9 +36,18 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
...
@@ -36,9 +36,18 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
// Read Write is the same op.
// Read Write is the same op.
continue
;
continue
;
}
}
bool
has_dep
=
false
;
for
(
auto
*
r_out
:
read_op
->
Outputs
())
{
for
(
auto
*
w_in
:
write_op
->
Inputs
())
{
if
(
r_out
->
Node
()
==
w_in
->
Node
())
{
has_dep
=
true
;
break
;
}
}
}
if
(
has_dep
)
continue
;
auto
*
dep_var
=
new
DummyVarHandle
(
auto
*
dep_var
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
graph
->
CreateEmptyNode
(
"dummy"
,
ir
::
Node
::
Type
::
kVariable
));
read_op
->
AddOutput
(
dep_var
);
read_op
->
AddOutput
(
dep_var
);
write_op
->
AddInput
(
dep_var
);
write_op
->
AddInput
(
dep_var
);
graph
->
Get
<
GraphDepVars
>
(
"dep_vars"
).
emplace
(
dep_var
);
graph
->
Get
<
GraphDepVars
>
(
"dep_vars"
).
emplace
(
dep_var
);
...
@@ -49,7 +58,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
...
@@ -49,7 +58,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
}
}
VarHandle
*
SSAGraphBuilder
::
CreateOrGetLatestVarHandle
(
VarHandle
*
SSAGraphBuilder
::
CreateOrGetLatestVarHandle
(
Graph
*
graph
,
ir
::
Node
*
node
,
const
platform
::
Place
&
place
,
ir
::
Graph
*
graph
,
ir
::
Node
*
node
,
const
platform
::
Place
&
place
,
size_t
place_offset
)
{
size_t
place_offset
)
{
auto
&
var_holders
=
graph
->
Get
<
GraphVars
>
(
"vars"
)[
place_offset
];
auto
&
var_holders
=
graph
->
Get
<
GraphVars
>
(
"vars"
)[
place_offset
];
auto
&
var_holder
=
var_holders
[
node
->
Name
()];
auto
&
var_holder
=
var_holders
[
node
->
Name
()];
...
@@ -70,7 +79,7 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
...
@@ -70,7 +79,7 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
return
var
;
return
var
;
}
}
void
SSAGraphBuilder
::
CreateOpOutput
(
Graph
*
graph
,
OpHandleBase
*
op_handle
,
void
SSAGraphBuilder
::
CreateOpOutput
(
ir
::
Graph
*
graph
,
OpHandleBase
*
op_handle
,
ir
::
Node
*
new_node
,
ir
::
Node
*
new_node
,
const
platform
::
Place
&
place
,
const
platform
::
Place
&
place
,
size_t
place_offset
)
{
size_t
place_offset
)
{
...
@@ -82,13 +91,12 @@ void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
...
@@ -82,13 +91,12 @@ void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
op_handle
->
AddOutput
(
var
);
op_handle
->
AddOutput
(
var
);
}
}
void
SSAGraphBuilder
::
AddOutputToLeafOps
(
Graph
*
graph
)
{
void
SSAGraphBuilder
::
AddOutputToLeafOps
(
ir
::
Graph
*
graph
)
{
for
(
auto
&
op
:
graph
->
Get
<
GraphOps
>
(
"ops"
))
{
for
(
auto
&
op
:
graph
->
Get
<
GraphOps
>
(
"ops"
))
{
if
(
!
op
->
Outputs
().
empty
())
{
if
(
!
op
->
Outputs
().
empty
())
{
continue
;
continue
;
}
}
auto
*
dummy_leaf
=
new
DummyVarHandle
(
auto
*
dummy_leaf
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
graph
->
CreateEmptyNode
(
"dummy"
,
ir
::
Node
::
Type
::
kVariable
));
graph
->
Get
<
GraphDepVars
>
(
"dep_vars"
).
emplace
(
dummy_leaf
);
graph
->
Get
<
GraphDepVars
>
(
"dep_vars"
).
emplace
(
dummy_leaf
);
op
->
AddOutput
(
dummy_leaf
);
op
->
AddOutput
(
dummy_leaf
);
}
}
...
...
paddle/fluid/framework/details/ssa_graph_builder.h
浏览文件 @
687a3222
...
@@ -57,26 +57,23 @@ class SSAGraphBuilder : public ir::Pass {
...
@@ -57,26 +57,23 @@ class SSAGraphBuilder : public ir::Pass {
DISABLE_COPY_AND_ASSIGN
(
SSAGraphBuilder
);
DISABLE_COPY_AND_ASSIGN
(
SSAGraphBuilder
);
protected:
protected:
/**
/*
* We only handle write after read(WAR), since it should not have a write
Dependency graph has been constructed. However, there are still data
* after write in program. If there are write after write operators, we need
hazards need to be handled.
* prune them.
*
* https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
*/
*/
static
void
PolishGraphToSupportDataHazards
(
Graph
*
graph
);
static
void
PolishGraphToSupportDataHazards
(
ir
::
Graph
*
graph
);
static
VarHandle
*
CreateOrGetLatestVarHandle
(
Graph
*
graph
,
ir
::
Node
*
node
,
static
VarHandle
*
CreateOrGetLatestVarHandle
(
ir
::
Graph
*
graph
,
ir
::
Node
*
node
,
const
platform
::
Place
&
place
,
const
platform
::
Place
&
place
,
size_t
place_offset
);
size_t
place_offset
);
// Add an output variable (each_var_name, place, place_offset) to op_handle,
// Add an output variable (each_var_name, place, place_offset) to op_handle,
// which belongs to graph
// which belongs to graph
static
void
CreateOpOutput
(
Graph
*
graph
,
OpHandleBase
*
op_handle
,
static
void
CreateOpOutput
(
ir
::
Graph
*
graph
,
OpHandleBase
*
op_handle
,
ir
::
Node
*
new_node
,
const
platform
::
Place
&
place
,
ir
::
Node
*
new_node
,
const
platform
::
Place
&
place
,
size_t
place_offset
);
size_t
place_offset
);
static
void
AddOutputToLeafOps
(
Graph
*
graph
);
static
void
AddOutputToLeafOps
(
ir
::
Graph
*
graph
);
};
};
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/framework/details/ssa_graph_checker.cc
浏览文件 @
687a3222
...
@@ -20,7 +20,7 @@ namespace paddle {
...
@@ -20,7 +20,7 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
bool
SSAGraghBuilderWithChecker
::
IsValidGraph
(
const
Graph
*
graph
)
const
{
bool
SSAGraghBuilderWithChecker
::
IsValidGraph
(
const
ir
::
Graph
*
graph
)
const
{
std
::
unordered_map
<
OpHandleBase
*
,
size_t
>
pending_ops
;
std
::
unordered_map
<
OpHandleBase
*
,
size_t
>
pending_ops
;
std
::
unordered_set
<
VarHandleBase
*>
pending_vars
;
std
::
unordered_set
<
VarHandleBase
*>
pending_vars
;
std
::
unordered_set
<
VarHandleBase
*>
ready_vars
;
std
::
unordered_set
<
VarHandleBase
*>
ready_vars
;
...
...
paddle/fluid/framework/details/ssa_graph_checker.h
浏览文件 @
687a3222
...
@@ -28,7 +28,8 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
...
@@ -28,7 +28,8 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
std
::
unique_ptr
<
SSAGraphBuilder
>&&
builder
)
std
::
unique_ptr
<
SSAGraphBuilder
>&&
builder
)
:
builder_
(
std
::
move
(
builder
))
{}
:
builder_
(
std
::
move
(
builder
))
{}
std
::
unique_ptr
<
Graph
>
Apply
(
std
::
unique_ptr
<
Graph
>
graph
)
const
override
{
std
::
unique_ptr
<
ir
::
Graph
>
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
{
auto
new_graph
=
builder_
->
Apply
(
std
::
move
(
graph
));
auto
new_graph
=
builder_
->
Apply
(
std
::
move
(
graph
));
PADDLE_ENFORCE
(
IsValidGraph
(
new_graph
.
get
()));
PADDLE_ENFORCE
(
IsValidGraph
(
new_graph
.
get
()));
return
new_graph
;
return
new_graph
;
...
@@ -38,7 +39,7 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
...
@@ -38,7 +39,7 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
return
builder_
->
GetVarDeviceID
(
var_name
);
return
builder_
->
GetVarDeviceID
(
var_name
);
}
}
bool
IsValidGraph
(
const
Graph
*
graph
)
const
;
bool
IsValidGraph
(
const
ir
::
Graph
*
graph
)
const
;
private:
private:
std
::
unique_ptr
<
SSAGraphBuilder
>
builder_
;
std
::
unique_ptr
<
SSAGraphBuilder
>
builder_
;
...
...
paddle/fluid/framework/details/ssa_graph_printer.cc
浏览文件 @
687a3222
...
@@ -21,7 +21,7 @@ namespace framework {
...
@@ -21,7 +21,7 @@ namespace framework {
namespace
details
{
namespace
details
{
template
<
typename
Callback
>
template
<
typename
Callback
>
static
inline
void
IterAllVar
(
const
Graph
&
graph
,
Callback
callback
)
{
static
inline
void
IterAllVar
(
const
ir
::
Graph
&
graph
,
Callback
callback
)
{
for
(
auto
&
each
:
graph
.
Get
<
GraphVars
>
(
"vars"
))
{
for
(
auto
&
each
:
graph
.
Get
<
GraphVars
>
(
"vars"
))
{
for
(
auto
&
pair1
:
each
)
{
for
(
auto
&
pair1
:
each
)
{
for
(
auto
&
pair2
:
pair1
.
second
)
{
for
(
auto
&
pair2
:
pair1
.
second
)
{
...
@@ -35,7 +35,7 @@ static inline void IterAllVar(const Graph &graph, Callback callback) {
...
@@ -35,7 +35,7 @@ static inline void IterAllVar(const Graph &graph, Callback callback) {
}
}
}
}
void
GraphvizSSAGraphPrinter
::
Print
(
const
Graph
&
graph
,
void
GraphvizSSAGraphPrinter
::
Print
(
const
ir
::
Graph
&
graph
,
std
::
ostream
&
sout
)
const
{
std
::
ostream
&
sout
)
const
{
size_t
var_id
=
0
;
size_t
var_id
=
0
;
std
::
unordered_map
<
const
VarHandleBase
*
,
size_t
>
vars
;
std
::
unordered_map
<
const
VarHandleBase
*
,
size_t
>
vars
;
...
...
paddle/fluid/framework/details/ssa_graph_printer.h
浏览文件 @
687a3222
...
@@ -25,12 +25,12 @@ namespace details {
...
@@ -25,12 +25,12 @@ namespace details {
class
SSAGraphPrinter
{
class
SSAGraphPrinter
{
public:
public:
virtual
~
SSAGraphPrinter
()
{}
virtual
~
SSAGraphPrinter
()
{}
virtual
void
Print
(
const
Graph
&
graph
,
std
::
ostream
&
sout
)
const
=
0
;
virtual
void
Print
(
const
ir
::
Graph
&
graph
,
std
::
ostream
&
sout
)
const
=
0
;
};
};
class
GraphvizSSAGraphPrinter
:
public
SSAGraphPrinter
{
class
GraphvizSSAGraphPrinter
:
public
SSAGraphPrinter
{
public:
public:
void
Print
(
const
Graph
&
graph
,
std
::
ostream
&
sout
)
const
override
;
void
Print
(
const
ir
::
Graph
&
graph
,
std
::
ostream
&
sout
)
const
override
;
};
};
class
SSAGraghBuilderWithPrinter
:
public
SSAGraphBuilder
{
class
SSAGraghBuilderWithPrinter
:
public
SSAGraphBuilder
{
...
@@ -50,7 +50,8 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
...
@@ -50,7 +50,8 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
stream_ptr_
(
std
::
move
(
sout
)),
stream_ptr_
(
std
::
move
(
sout
)),
stream_ref_
(
*
stream_ptr_
)
{}
stream_ref_
(
*
stream_ptr_
)
{}
std
::
unique_ptr
<
Graph
>
Apply
(
std
::
unique_ptr
<
Graph
>
graph
)
const
override
{
std
::
unique_ptr
<
ir
::
Graph
>
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
{
auto
new_graph
=
builder_
->
Apply
(
std
::
move
(
graph
));
auto
new_graph
=
builder_
->
Apply
(
std
::
move
(
graph
));
printer_
->
Print
(
*
new_graph
,
stream_ref_
);
printer_
->
Print
(
*
new_graph
,
stream_ref_
);
return
new_graph
;
return
new_graph
;
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
浏览文件 @
687a3222
...
@@ -21,7 +21,8 @@ namespace framework {
...
@@ -21,7 +21,8 @@ namespace framework {
namespace
details
{
namespace
details
{
ThreadedSSAGraphExecutor
::
ThreadedSSAGraphExecutor
(
ThreadedSSAGraphExecutor
::
ThreadedSSAGraphExecutor
(
const
ExecutionStrategy
&
strategy
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
ExecutionStrategy
&
strategy
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
std
::
unique_ptr
<
Graph
>
&&
graph
)
const
std
::
vector
<
platform
::
Place
>
&
places
,
std
::
unique_ptr
<
ir
::
Graph
>
&&
graph
)
:
graph_
(
std
::
move
(
graph
)),
:
graph_
(
std
::
move
(
graph
)),
pool_
(
strategy
.
num_threads_
>=
2
?
new
::
ThreadPool
(
strategy
.
num_threads_
)
pool_
(
strategy
.
num_threads_
>=
2
?
new
::
ThreadPool
(
strategy
.
num_threads_
)
:
nullptr
),
:
nullptr
),
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
浏览文件 @
687a3222
...
@@ -40,7 +40,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
...
@@ -40,7 +40,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
ThreadedSSAGraphExecutor
(
const
ExecutionStrategy
&
strategy
,
ThreadedSSAGraphExecutor
(
const
ExecutionStrategy
&
strategy
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
std
::
unique_ptr
<
Graph
>
&&
graph
);
std
::
unique_ptr
<
ir
::
Graph
>
&&
graph
);
// Run a SSAGraph by a thread pool
// Run a SSAGraph by a thread pool
// Use topological sort algorithm
// Use topological sort algorithm
...
@@ -53,7 +53,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
...
@@ -53,7 +53,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
details
::
OpHandleBase
*
op
);
details
::
OpHandleBase
*
op
);
private:
private:
std
::
unique_ptr
<
Graph
>
graph_
;
std
::
unique_ptr
<
ir
::
Graph
>
graph_
;
std
::
unique_ptr
<::
ThreadPool
>
pool_
;
std
::
unique_ptr
<::
ThreadPool
>
pool_
;
std
::
vector
<
Scope
*>
local_scopes_
;
std
::
vector
<
Scope
*>
local_scopes_
;
std
::
vector
<
platform
::
Place
>
places_
;
std
::
vector
<
platform
::
Place
>
places_
;
...
...
paddle/fluid/framework/details/var_handle.cc
浏览文件 @
687a3222
...
@@ -26,7 +26,7 @@ std::string VarHandle::DebugString() const {
...
@@ -26,7 +26,7 @@ std::string VarHandle::DebugString() const {
return
ss
.
str
();
return
ss
.
str
();
}
}
std
::
string
DummyVarHandle
::
DebugString
()
const
{
return
"dummy"
;
}
std
::
string
DummyVarHandle
::
DebugString
()
const
{
return
node_
->
Name
()
;
}
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
687a3222
cc_library
(
node SRCS node.cc DEPS proto_desc
)
cc_library
(
node SRCS node.cc DEPS proto_desc
)
cc_library
(
graph SRCS graph.cc DEPS node
)
cc_library
(
graph SRCS graph.cc DEPS node
)
cc_library
(
graph_helper SRCS graph_helper.cc DEPS graph
)
cc_library
(
pass SRCS pass.cc DEPS graph node
)
cc_library
(
pass SRCS pass.cc DEPS graph node
)
cc_test
(
graph_test SRCS graph_test.cc DEPS graph op_registry
)
cc_test
(
graph_
test SRCS graph_test.cc DEPS graph proto_desc
op_registry
)
cc_test
(
graph_
helper_test SRCS graph_helper_test.cc DEPS graph_helper
op_registry
)
paddle/fluid/framework/ir/graph.cc
浏览文件 @
687a3222
...
@@ -12,14 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,14 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include <algorithm>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/framework/var_desc.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
ir
{
// NOTE(paddle-dev): This graph contains circle.
Graph
::
Graph
(
const
ProgramDesc
&
program
)
:
program_
(
program
)
{
Graph
::
Graph
(
const
ProgramDesc
&
program
)
:
program_
(
program
)
{
VLOG
(
3
)
<<
"block in program:"
<<
program_
.
Size
();
VLOG
(
3
)
<<
"block in program:"
<<
program_
.
Size
();
std
::
unordered_map
<
std
::
string
,
VarDesc
*>
all_vars
;
std
::
unordered_map
<
std
::
string
,
VarDesc
*>
all_vars
;
...
@@ -27,40 +31,87 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
...
@@ -27,40 +31,87 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
all_vars
.
emplace
(
var
->
Name
(),
var
);
all_vars
.
emplace
(
var
->
Name
(),
var
);
}
}
std
::
map
<
std
::
string
,
ir
::
Node
*
>
var_nodes
;
std
::
map
<
std
::
string
,
std
::
vector
<
ir
::
Node
*>
>
var_nodes
;
for
(
auto
*
op
:
program
.
Block
(
0
).
AllOps
())
{
for
(
auto
*
op
:
program
.
Block
(
0
).
AllOps
())
{
ir
::
Node
*
node
=
CreateOpNode
(
op
);
ir
::
Node
*
node
=
CreateOpNode
(
op
);
// For input args, reuse the same var name if it was created before.
// Otherwise, create a new one.
for
(
auto
&
each_var_name
:
op
->
InputArgumentNames
())
{
for
(
auto
&
each_var_name
:
op
->
InputArgumentNames
())
{
ir
::
Node
*
var
=
nullptr
;
ir
::
Node
*
var
=
nullptr
;
if
(
var_nodes
.
find
(
each_var_name
)
!=
var_nodes
.
end
())
{
if
(
var_nodes
.
find
(
each_var_name
)
!=
var_nodes
.
end
())
{
var
=
var_nodes
.
at
(
each_var_name
);
var
=
var_nodes
.
at
(
each_var_name
)
.
back
()
;
}
else
if
(
all_vars
.
count
(
each_var_name
)
!=
0
)
{
}
else
if
(
all_vars
.
count
(
each_var_name
)
!=
0
)
{
var
=
CreateVarNode
(
all_vars
.
at
(
each_var_name
));
var
=
CreateVarNode
(
all_vars
.
at
(
each_var_name
));
var_nodes
[
each_var_name
]
=
var
;
var_nodes
[
each_var_name
]
.
push_back
(
var
)
;
}
else
{
}
else
{
//
TODO(paddle-dev): Seems some assumption doesn't hold?
//
Operation input var can be optional (dispensable). Which means
VLOG
(
3
)
<<
op
->
Type
()
// the operation doesn't really need the var at runtime. In this
<<
" input var not in all_var list: "
<<
each_var_name
;
// case, the no-existed var is ready at the beginning.
var
=
CreateEmptyNode
(
each_var_name
,
ir
::
Node
::
Type
::
kVariable
);
var
=
CreateEmptyNode
(
each_var_name
,
ir
::
Node
::
Type
::
kVariable
);
var_nodes
[
each_var_name
]
=
var
;
var_nodes
[
each_var_name
]
.
push_back
(
var
)
;
}
}
node
->
inputs
.
push_back
(
var
);
node
->
inputs
.
push_back
(
var
);
var
->
outputs
.
push_back
(
node
);
var
->
outputs
.
push_back
(
node
);
}
}
// For output args, always create a new var.
for
(
auto
&
each_var_name
:
op
->
OutputArgumentNames
())
{
for
(
auto
&
each_var_name
:
op
->
OutputArgumentNames
())
{
ir
::
Node
*
var
=
nullptr
;
ir
::
Node
*
var
=
CreateVarNode
(
all_vars
.
at
(
each_var_name
));
if
(
var_nodes
.
find
(
each_var_name
)
!=
var_nodes
.
end
())
{
var_nodes
[
each_var_name
].
push_back
(
var
);
var
=
var_nodes
.
at
(
each_var_name
);
}
else
{
var
=
CreateVarNode
(
all_vars
.
at
(
each_var_name
));
var_nodes
[
each_var_name
]
=
var
;
}
node
->
outputs
.
push_back
(
var
);
node
->
outputs
.
push_back
(
var
);
var
->
inputs
.
push_back
(
node
);
var
->
inputs
.
push_back
(
node
);
}
}
}
}
/**
* We only handle write after read(WAR), since it should not have a write
* after write in program. If there are write after write operators, we need
* prune them.
*
* https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
*/
for
(
auto
&
var
:
var_nodes
)
{
auto
&
versions
=
var
.
second
;
if
(
versions
.
size
()
<=
1
)
continue
;
auto
it_new
=
versions
.
rbegin
();
auto
it_old
=
versions
.
rbegin
();
++
it_old
;
for
(;
it_old
!=
versions
.
rend
();
it_new
=
it_old
,
++
it_old
)
{
ir
::
Node
*
write_op
=
(
*
it_new
)
->
inputs
.
empty
()
?
nullptr
:
(
*
it_new
)
->
inputs
[
0
];
const
auto
&
read_ops
=
(
*
it_old
)
->
outputs
;
for
(
auto
*
read_op
:
read_ops
)
{
// Manually add a dependency var from read_op to write_op;
if
(
read_op
==
write_op
)
{
// Read Write is the same op.
continue
;
}
// 2 ops might have been connected via other vars.
bool
has_dep
=
false
;
for
(
ir
::
Node
*
r_out
:
read_op
->
outputs
)
{
for
(
ir
::
Node
*
w_in
:
write_op
->
inputs
)
{
if
(
r_out
==
w_in
)
{
has_dep
=
true
;
break
;
}
}
}
if
(
has_dep
)
continue
;
ir
::
Node
*
dep_var
=
CreateControlDepVar
();
read_op
->
outputs
.
push_back
(
dep_var
);
dep_var
->
inputs
.
push_back
(
read_op
);
write_op
->
inputs
.
push_back
(
dep_var
);
dep_var
->
outputs
.
push_back
(
write_op
);
}
}
}
}
bool
IsControlDepVar
(
const
ir
::
Node
&
var
)
{
return
var
.
Name
().
find
(
ir
::
Node
::
kControlDepVarName
)
!=
std
::
string
::
npos
;
}
}
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/ir/graph.h
浏览文件 @
687a3222
...
@@ -26,13 +26,14 @@ limitations under the License. */
...
@@ -26,13 +26,14 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
ir
{
class
Graph
{
class
Graph
{
public:
public:
explicit
Graph
(
const
ProgramDesc
&
program
);
explicit
Graph
(
const
ProgramDesc
&
program
);
virtual
~
Graph
()
{
virtual
~
Graph
()
{
for
(
auto
&
attr
:
attrs_
)
{
for
(
auto
&
attr
:
attrs_
)
{
attr_dels_
[
attr
.
first
]();
attr_dels_
[
attr
.
first
]();
}
}
attrs_
.
clear
();
attrs_
.
clear
();
...
@@ -40,12 +41,12 @@ class Graph {
...
@@ -40,12 +41,12 @@ class Graph {
}
}
template
<
typename
AttrType
>
template
<
typename
AttrType
>
AttrType
&
Get
(
const
std
::
string
&
attr_name
)
const
{
AttrType
&
Get
(
const
std
::
string
&
attr_name
)
const
{
return
*
boost
::
any_cast
<
AttrType
*>
(
attrs_
.
at
(
attr_name
));
return
*
boost
::
any_cast
<
AttrType
*>
(
attrs_
.
at
(
attr_name
));
}
}
template
<
typename
AttrType
>
template
<
typename
AttrType
>
void
Set
(
const
std
::
string
&
attr_name
,
AttrType
*
attr
)
{
void
Set
(
const
std
::
string
&
attr_name
,
AttrType
*
attr
)
{
PADDLE_ENFORCE
(
attrs_
.
count
(
attr_name
)
==
0
);
PADDLE_ENFORCE
(
attrs_
.
count
(
attr_name
)
==
0
);
attrs_
[
attr_name
]
=
attr
;
attrs_
[
attr_name
]
=
attr
;
attr_dels_
[
attr_name
]
=
[
attr
,
attr_name
]()
{
attr_dels_
[
attr_name
]
=
[
attr
,
attr_name
]()
{
...
@@ -54,29 +55,70 @@ class Graph {
...
@@ -54,29 +55,70 @@ class Graph {
};
};
}
}
ir
::
Node
*
CreateVarNode
(
VarDesc
*
var_desc
)
{
const
std
::
unordered_set
<
ir
::
Node
*>
&
Nodes
()
const
{
return
node_set_
;
}
nodes
.
emplace_back
(
new
ir
::
Node
(
var_desc
));
return
nodes
.
back
().
get
();
// Create a normal variable with non-null VarDesc.
ir
::
Node
*
CreateVarNode
(
VarDesc
*
var_desc
)
{
return
AddNode
(
new
ir
::
Node
(
var_desc
));
}
// Create a normal runnable operator with OpDesc.
ir
::
Node
*
CreateOpNode
(
OpDesc
*
op_desc
)
{
return
AddNode
(
new
ir
::
Node
(
op_desc
));
}
}
ir
::
Node
*
CreateOpNode
(
OpDesc
*
op_desc
)
{
// Create a control dependency var that connects 2 operations. The
nodes
.
emplace_back
(
new
ir
::
Node
(
op_desc
));
// var doesn't hold any data. Other than that, it's no different from
return
nodes
.
back
().
get
();
// other var, considering dependency analysis.
ir
::
Node
*
CreateControlDepVar
()
{
// TODO(panyx0718): control var name should be really unique.
const
std
::
string
name
=
string
::
Sprintf
(
"%s@%llu"
,
ir
::
Node
::
kControlDepVarName
,
node_set_
.
size
());
return
AddNode
(
new
ir
::
Node
(
name
,
ir
::
Node
::
Type
::
kVariable
));
}
}
ir
::
Node
*
CreateEmptyNode
(
const
std
::
string
&
name
,
ir
::
Node
::
Type
type
)
{
// A more free style way of creating a graph node. Mostly use for test
nodes
.
emplace_back
(
new
ir
::
Node
(
name
,
type
));
// or "copy" from another node. Avoid using it if possible.
return
nodes
.
back
().
get
();
ir
::
Node
*
CreateEmptyNode
(
const
std
::
string
&
name
,
ir
::
Node
::
Type
type
)
{
return
AddNode
(
new
ir
::
Node
(
name
,
type
));
}
}
std
::
vector
<
std
::
unique_ptr
<
ir
::
Node
>>
nodes
;
// Clear all node information of the graph and return the ownership of the
// nodes.
std
::
vector
<
std
::
unique_ptr
<
ir
::
Node
>>
ReleaseNodes
()
{
std
::
vector
<
std
::
unique_ptr
<
ir
::
Node
>>
ret
;
for
(
auto
&
n
:
nodes_
)
{
ret
.
emplace_back
(
n
.
second
.
release
());
}
nodes_
.
clear
();
node_set_
.
clear
();
return
ret
;
}
private:
private:
// This method takes ownership of `node`.
ir
::
Node
*
AddNode
(
ir
::
Node
*
node
)
{
PADDLE_ENFORCE
(
node_set_
.
find
(
node
)
==
node_set_
.
end
());
nodes_
[
node
].
reset
(
node
);
node_set_
.
insert
(
node
);
return
node
;
}
void
RemoveNode
(
ir
::
Node
*
node
)
{
PADDLE_ENFORCE
(
node_set_
.
find
(
node
)
!=
node_set_
.
end
());
node_set_
.
erase
(
node
);
nodes_
.
erase
(
node
);
}
// NOTE: program_ shouldn't be exposed to user.
// NOTE: program_ shouldn't be exposed to user.
const
ProgramDesc
&
program_
;
const
ProgramDesc
&
program_
;
std
::
map
<
std
::
string
,
boost
::
any
>
attrs_
;
std
::
map
<
std
::
string
,
boost
::
any
>
attrs_
;
std
::
map
<
std
::
string
,
std
::
function
<
void
(
void
)
>>
attr_dels_
;
std
::
map
<
std
::
string
,
std
::
function
<
void
(
void
)
>>
attr_dels_
;
std
::
map
<
ir
::
Node
*
,
std
::
unique_ptr
<
ir
::
Node
>>
nodes_
;
std
::
unordered_set
<
ir
::
Node
*>
node_set_
;
};
};
bool
IsControlDepVar
(
const
ir
::
Node
&
var
);
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/ir/graph_helper.cc
0 → 100644
浏览文件 @
687a3222
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
namespace
{
void
SortHelper
(
const
std
::
map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
&
adj_list
,
ir
::
Node
*
node
,
std
::
unordered_set
<
ir
::
Node
*>
*
visited
,
std
::
vector
<
ir
::
Node
*>
*
ret
)
{
visited
->
insert
(
node
);
for
(
auto
adj
:
adj_list
.
at
(
node
))
{
if
(
visited
->
find
(
adj
)
==
visited
->
end
())
{
SortHelper
(
adj_list
,
adj
,
visited
,
ret
);
}
}
VLOG
(
3
)
<<
"topology sort insert: "
<<
node
->
Name
()
<<
reinterpret_cast
<
void
*>
(
node
)
<<
" input "
<<
node
->
inputs
.
size
();
ret
->
push_back
(
node
);
}
bool
HasCircleHelper
(
ir
::
Node
*
node
,
const
std
::
map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
&
adj_list
,
std
::
unordered_set
<
ir
::
Node
*>
*
visited
,
std
::
unordered_set
<
ir
::
Node
*>
*
in_trace
)
{
if
(
visited
->
find
(
node
)
==
visited
->
end
())
{
visited
->
insert
(
node
);
in_trace
->
insert
(
node
);
for
(
ir
::
Node
*
in
:
adj_list
.
at
(
node
))
{
if
(
visited
->
find
(
in
)
==
visited
->
end
()
&&
HasCircleHelper
(
in
,
adj_list
,
visited
,
in_trace
))
{
return
true
;
}
else
if
(
in_trace
->
find
(
in
)
!=
in_trace
->
end
())
{
return
true
;
}
}
}
in_trace
->
erase
(
node
);
return
false
;
}
bool
HasCircleInternal
(
const
std
::
map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
&
adj_list
)
{
std
::
unordered_set
<
ir
::
Node
*>
visited
;
std
::
unordered_set
<
ir
::
Node
*>
in_trace
;
for
(
auto
&
adj
:
adj_list
)
{
if
(
HasCircleHelper
(
adj
.
first
,
adj_list
,
&
visited
,
&
in_trace
))
{
return
true
;
}
}
return
false
;
}
}
// namespace
bool
HasCircle
(
const
Graph
&
graph
)
{
return
HasCircleInternal
(
BuildOperationAdjList
(
graph
));
}
std
::
vector
<
ir
::
Node
*>
TopologySortOperations
(
const
Graph
&
graph
)
{
std
::
map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
adj_list
=
BuildOperationAdjList
(
graph
);
PADDLE_ENFORCE
(
!
HasCircleInternal
(
adj_list
));
std
::
unordered_set
<
ir
::
Node
*>
visited
;
std
::
vector
<
ir
::
Node
*>
ret
;
for
(
auto
adj
:
adj_list
)
{
if
(
visited
.
find
(
adj
.
first
)
==
visited
.
end
())
{
SortHelper
(
adj_list
,
adj
.
first
,
&
visited
,
&
ret
);
}
}
return
ret
;
}
std
::
map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
BuildOperationAdjList
(
const
Graph
&
graph
)
{
std
::
map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
adj_list
;
for
(
auto
&
n
:
graph
.
Nodes
())
{
if
(
n
->
NodeType
()
!=
ir
::
Node
::
Type
::
kOperation
)
continue
;
if
(
adj_list
.
find
(
n
)
==
adj_list
.
end
())
{
adj_list
[
n
]
=
std
::
unordered_set
<
ir
::
Node
*>
();
}
for
(
auto
&
var
:
n
->
inputs
)
{
for
(
auto
&
adj_n
:
var
->
inputs
)
{
PADDLE_ENFORCE
(
adj_n
->
NodeType
()
==
ir
::
Node
::
Type
::
kOperation
);
adj_list
[
n
].
insert
(
adj_n
);
VLOG
(
3
)
<<
"adj "
<<
adj_n
->
Name
()
<<
reinterpret_cast
<
void
*>
(
adj_n
)
<<
" -> "
<<
n
->
Name
()
<<
reinterpret_cast
<
void
*>
(
n
)
<<
" via "
<<
var
->
Name
()
<<
reinterpret_cast
<
void
*>
(
var
);
}
}
}
return
adj_list
;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/graph_helper.h
0 → 100644
浏览文件 @
687a3222
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <map>
#include <memory>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
// Test if the graph contains circle.
bool
HasCircle
(
const
Graph
&
graph
);
// Topology Sort the operations in the graph from inputs to outputs.
// `graph` cannot contain circle.
std
::
vector
<
ir
::
Node
*>
TopologySortOperations
(
const
Graph
&
graph
);
// Build an adjacency list of operations for the `graph`.
std
::
map
<
ir
::
Node
*
,
std
::
unordered_set
<
ir
::
Node
*>>
BuildOperationAdjList
(
const
Graph
&
graph
);
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/graph_helper_test.cc
0 → 100644
浏览文件 @
687a3222
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/graph.h"
#include <string>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/program_desc.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
void
BuildCircleGraph
(
Graph
*
g
)
{
ir
::
Node
*
o1
=
g
->
CreateEmptyNode
(
"op1"
,
Node
::
Type
::
kOperation
);
ir
::
Node
*
v1
=
g
->
CreateEmptyNode
(
"var1"
,
Node
::
Type
::
kVariable
);
o1
->
outputs
.
push_back
(
v1
);
o1
->
inputs
.
push_back
(
v1
);
v1
->
inputs
.
push_back
(
o1
);
v1
->
outputs
.
push_back
(
o1
);
}
void
BuildCircleGraph2
(
Graph
*
g
)
{
ir
::
Node
*
o1
=
g
->
CreateEmptyNode
(
"op1"
,
Node
::
Type
::
kOperation
);
ir
::
Node
*
o2
=
g
->
CreateEmptyNode
(
"op2"
,
Node
::
Type
::
kOperation
);
ir
::
Node
*
v1
=
g
->
CreateEmptyNode
(
"var1"
,
Node
::
Type
::
kVariable
);
ir
::
Node
*
v2
=
g
->
CreateEmptyNode
(
"var2"
,
Node
::
Type
::
kVariable
);
o1
->
outputs
.
push_back
(
v1
);
o2
->
inputs
.
push_back
(
v1
);
v1
->
inputs
.
push_back
(
o1
);
v1
->
outputs
.
push_back
(
o2
);
o2
->
outputs
.
push_back
(
v2
);
o1
->
inputs
.
push_back
(
v2
);
v2
->
inputs
.
push_back
(
o2
);
v2
->
outputs
.
push_back
(
o1
);
}
void
BuildNoCircleGraph
(
Graph
*
g
)
{
ir
::
Node
*
o1
=
g
->
CreateEmptyNode
(
"op1"
,
Node
::
Type
::
kOperation
);
ir
::
Node
*
o2
=
g
->
CreateEmptyNode
(
"op2"
,
Node
::
Type
::
kOperation
);
ir
::
Node
*
o3
=
g
->
CreateEmptyNode
(
"op3"
,
Node
::
Type
::
kOperation
);
ir
::
Node
*
o4
=
g
->
CreateEmptyNode
(
"op4"
,
Node
::
Type
::
kOperation
);
ir
::
Node
*
o5
=
g
->
CreateEmptyNode
(
"op5"
,
Node
::
Type
::
kOperation
);
ir
::
Node
*
v1
=
g
->
CreateEmptyNode
(
"var1"
,
Node
::
Type
::
kVariable
);
ir
::
Node
*
v2
=
g
->
CreateEmptyNode
(
"var2"
,
Node
::
Type
::
kVariable
);
ir
::
Node
*
v3
=
g
->
CreateEmptyNode
(
"var3"
,
Node
::
Type
::
kVariable
);
ir
::
Node
*
v4
=
g
->
CreateEmptyNode
(
"var4"
,
Node
::
Type
::
kVariable
);
// o1->v1->o2
o1
->
outputs
.
push_back
(
v1
);
o2
->
inputs
.
push_back
(
v1
);
v1
->
inputs
.
push_back
(
o1
);
v1
->
outputs
.
push_back
(
o2
);
// o2->v2->o3
// o2->v2->o4
o2
->
outputs
.
push_back
(
v2
);
o3
->
inputs
.
push_back
(
v2
);
o4
->
inputs
.
push_back
(
v2
);
v2
->
inputs
.
push_back
(
o2
);
v2
->
outputs
.
push_back
(
o3
);
v2
->
outputs
.
push_back
(
o4
);
// o2->v3->o5
o2
->
outputs
.
push_back
(
v3
);
o5
->
inputs
.
push_back
(
v3
);
v3
->
inputs
.
push_back
(
o2
);
v3
->
outputs
.
push_back
(
o5
);
// o3-v4->o5
o3
->
outputs
.
push_back
(
v4
);
o5
->
inputs
.
push_back
(
v4
);
v4
->
inputs
.
push_back
(
o3
);
v4
->
outputs
.
push_back
(
o5
);
}
TEST
(
GraphHelperTest
,
Basic
)
{
ProgramDesc
prog
;
Graph
g
(
prog
);
BuildCircleGraph
(
&
g
);
ASSERT_TRUE
(
HasCircle
(
g
));
Graph
g2
(
prog
);
BuildCircleGraph2
(
&
g2
);
ASSERT_TRUE
(
HasCircle
(
g2
));
auto
adj_list
=
BuildOperationAdjList
(
g2
);
for
(
auto
&
adj
:
adj_list
)
{
auto
&
adj_set
=
adj
.
second
;
if
(
adj
.
first
->
Name
()
==
"op1"
)
{
ASSERT_EQ
((
*
adj_set
.
begin
())
->
Name
(),
"op2"
);
}
else
if
(
adj
.
first
->
Name
()
==
"op2"
)
{
ASSERT_EQ
((
*
adj_set
.
begin
())
->
Name
(),
"op1"
);
}
else
{
ASSERT_TRUE
(
false
);
}
}
Graph
g3
(
prog
);
BuildNoCircleGraph
(
&
g3
);
ASSERT_FALSE
(
HasCircle
(
g3
));
auto
sorted
=
TopologySortOperations
(
g3
);
std
::
map
<
std
::
string
,
size_t
>
node_map
;
for
(
size_t
i
=
0
;
i
<
sorted
.
size
();
++
i
)
{
node_map
[
sorted
[
i
]
->
Name
()]
=
i
;
}
ASSERT_EQ
(
node_map
.
at
(
"op1"
),
0
);
ASSERT_EQ
(
node_map
.
at
(
"op2"
),
1
);
ASSERT_TRUE
(
node_map
.
at
(
"op3"
)
<
node_map
.
at
(
"op5"
));
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/graph_test.cc
浏览文件 @
687a3222
...
@@ -76,6 +76,7 @@ TEST(GraphTest, Basic) {
...
@@ -76,6 +76,7 @@ TEST(GraphTest, Basic) {
op
->
SetType
(
"sum"
);
op
->
SetType
(
"sum"
);
op
->
SetInput
(
"X"
,
{
"test_a"
,
"test_b"
,
"test_c"
});
op
->
SetInput
(
"X"
,
{
"test_a"
,
"test_b"
,
"test_c"
});
op
->
SetOutput
(
"Out"
,
{
"test_out"
});
op
->
SetOutput
(
"Out"
,
{
"test_out"
});
op
->
SetAttr
(
"op_role"
,
1
);
prog
.
MutableBlock
(
0
)
->
Var
(
"test_a"
)
->
SetType
(
proto
::
VarType
::
SELECTED_ROWS
);
prog
.
MutableBlock
(
0
)
->
Var
(
"test_a"
)
->
SetType
(
proto
::
VarType
::
SELECTED_ROWS
);
prog
.
MutableBlock
(
0
)
->
Var
(
"test_b"
)
->
SetType
(
proto
::
VarType
::
SELECTED_ROWS
);
prog
.
MutableBlock
(
0
)
->
Var
(
"test_b"
)
->
SetType
(
proto
::
VarType
::
SELECTED_ROWS
);
...
@@ -92,21 +93,22 @@ TEST(GraphTest, Basic) {
...
@@ -92,21 +93,22 @@ TEST(GraphTest, Basic) {
ASSERT_EQ
(
proto
::
VarType
::
LOD_TENSOR
,
ASSERT_EQ
(
proto
::
VarType
::
LOD_TENSOR
,
prog
.
MutableBlock
(
0
)
->
Var
(
"test_out"
)
->
GetType
());
prog
.
MutableBlock
(
0
)
->
Var
(
"test_out"
)
->
GetType
());
std
::
unique_ptr
<
Graph
>
g
(
new
Graph
(
prog
));
std
::
unique_ptr
<
ir
::
Graph
>
g
(
new
ir
::
Graph
(
prog
));
ASSERT_EQ
(
g
->
nodes
[
0
]
->
Name
(),
"sum"
);
std
::
vector
<
ir
::
Node
*>
nodes
(
g
->
Nodes
().
begin
(),
g
->
Nodes
().
end
());
ASSERT_EQ
(
g
->
nodes
[
0
]
->
inputs
[
0
]
->
Name
(),
"test_a"
);
for
(
ir
::
Node
*
n
:
nodes
)
{
ASSERT_EQ
(
g
->
nodes
[
0
]
->
inputs
[
1
]
->
Name
(),
"test_b"
);
if
(
n
->
Name
()
==
"sum"
)
{
ASSERT_EQ
(
g
->
nodes
[
0
]
->
inputs
[
2
]
->
Name
(),
"test_c"
);
ASSERT_EQ
(
n
->
inputs
.
size
(),
3
);
ASSERT_EQ
(
g
->
nodes
[
0
]
->
outputs
[
0
]
->
Name
(),
"test_out"
);
ASSERT_EQ
(
n
->
outputs
.
size
(),
1
);
ASSERT_EQ
(
g
->
nodes
[
1
]
->
Name
(),
"test_a"
);
}
else
if
(
n
->
Name
()
==
"test_a"
||
n
->
Name
()
==
"test_b"
||
ASSERT_EQ
(
g
->
nodes
[
1
]
->
outputs
[
0
]
->
Name
(),
"sum"
);
n
->
Name
()
==
"test_c"
)
{
ASSERT_EQ
(
g
->
nodes
[
2
]
->
Name
(),
"test_b"
);
ASSERT_EQ
(
n
->
inputs
.
size
(),
0
);
ASSERT_EQ
(
g
->
nodes
[
2
]
->
outputs
[
0
]
->
Name
(),
"sum"
);
ASSERT_EQ
(
n
->
outputs
.
size
(),
1
);
ASSERT_EQ
(
g
->
nodes
[
3
]
->
Name
(),
"test_c"
);
}
else
if
(
n
->
Name
()
==
"test_out"
)
{
ASSERT_EQ
(
g
->
nodes
[
3
]
->
outputs
[
0
]
->
Name
(),
"sum"
);
ASSERT_EQ
(
n
->
inputs
.
size
(),
1
);
ASSERT_EQ
(
g
->
nodes
[
4
]
->
Name
(),
"test_out"
);
ASSERT_EQ
(
n
->
outputs
.
size
(),
0
);
ASSERT_EQ
(
g
->
nodes
[
4
]
->
inputs
[
0
]
->
Name
(),
"sum"
);
}
ASSERT_EQ
(
g
->
nodes
.
size
(),
5
);
}
ASSERT_EQ
(
nodes
.
size
(),
5
);
}
}
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/ir/node.cc
浏览文件 @
687a3222
...
@@ -15,5 +15,9 @@ limitations under the License. */
...
@@ -15,5 +15,9 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/ir/node.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{}
// namespace framework
namespace
framework
{
namespace
ir
{
const
char
Node
::
kControlDepVarName
[]
=
"__control_var"
;
}
// namespace ir
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/ir/node.h
浏览文件 @
687a3222
...
@@ -27,6 +27,8 @@ namespace ir {
...
@@ -27,6 +27,8 @@ namespace ir {
class
Node
{
class
Node
{
public:
public:
enum
class
Type
{
kOperation
,
kVariable
};
enum
class
Type
{
kOperation
,
kVariable
};
static
const
char
kControlDepVarName
[];
explicit
Node
(
const
std
::
string
&
name
,
Type
type
)
explicit
Node
(
const
std
::
string
&
name
,
Type
type
)
:
name_
(
name
),
var_desc_
(
nullptr
),
op_desc_
(
nullptr
),
type_
(
type
)
{}
:
name_
(
name
),
var_desc_
(
nullptr
),
op_desc_
(
nullptr
),
type_
(
type
)
{}
...
@@ -50,6 +52,7 @@ class Node {
...
@@ -50,6 +52,7 @@ class Node {
PADDLE_ENFORCE
(
type_
==
Type
::
kVariable
);
PADDLE_ENFORCE
(
type_
==
Type
::
kVariable
);
return
var_desc_
;
return
var_desc_
;
}
}
OpDesc
*
Op
()
{
OpDesc
*
Op
()
{
PADDLE_ENFORCE
(
type_
==
Type
::
kOperation
);
PADDLE_ENFORCE
(
type_
==
Type
::
kOperation
);
return
op_desc_
;
return
op_desc_
;
...
...
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
687a3222
...
@@ -132,7 +132,7 @@ ParallelExecutor::ParallelExecutor(
...
@@ -132,7 +132,7 @@ ParallelExecutor::ParallelExecutor(
#endif
#endif
}
}
builder_
=
builder_factory
.
Create
();
builder_
=
builder_factory
.
Create
();
std
::
unique_ptr
<
Graph
>
graph
(
new
Graph
(
main_program
));
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
main_program
));
graph
=
builder_
->
Apply
(
std
::
move
(
graph
));
graph
=
builder_
->
Apply
(
std
::
move
(
graph
));
member_
->
executor_
.
reset
(
new
details
::
ThreadedSSAGraphExecutor
(
member_
->
executor_
.
reset
(
new
details
::
ThreadedSSAGraphExecutor
(
exec_strategy
,
member_
->
local_scopes_
,
places
,
std
::
move
(
graph
)));
exec_strategy
,
member_
->
local_scopes_
,
places
,
std
::
move
(
graph
)));
...
...
paddle/fluid/inference/api/demo_ci/clean.sh
0 → 100755
浏览文件 @
687a3222
set
-x
cd
`
dirname
$0
`
rm
-rf
build/ data/
set
+x
paddle/fluid/memory/detail/buddy_allocator.cc
浏览文件 @
687a3222
...
@@ -15,6 +15,10 @@ limitations under the License. */
...
@@ -15,6 +15,10 @@ limitations under the License. */
#include "paddle/fluid/memory/detail/buddy_allocator.h"
#include "paddle/fluid/memory/detail/buddy_allocator.h"
#include "glog/logging.h"
#include "glog/logging.h"
DEFINE_bool
(
free_idle_memory
,
false
,
"If it is true, Paddle will try to free idle memory trunks during "
"running time."
);
namespace
paddle
{
namespace
paddle
{
namespace
memory
{
namespace
memory
{
namespace
detail
{
namespace
detail
{
...
@@ -152,13 +156,14 @@ void BuddyAllocator::Free(void* p) {
...
@@ -152,13 +156,14 @@ void BuddyAllocator::Free(void* p) {
pool_
.
insert
(
pool_
.
insert
(
IndexSizeAddress
(
block
->
index
(
cache_
),
block
->
total_size
(
cache_
),
block
));
IndexSizeAddress
(
block
->
index
(
cache_
),
block
->
total_size
(
cache_
),
block
));
if
(
FLAGS_free_idle_memory
)
{
// Clean up if existing too much free memory
// Clean up if existing too much free memory
// Prefer freeing fallback allocation first
// Prefer freeing fallback allocation first
CleanIdleFallBackAlloc
();
CleanIdleFallBackAlloc
();
// Free normal allocation
// Free normal allocation
CleanIdleNormalAlloc
();
CleanIdleNormalAlloc
();
}
}
}
size_t
BuddyAllocator
::
Used
()
{
return
total_used_
;
}
size_t
BuddyAllocator
::
Used
()
{
return
total_used_
;
}
...
...
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
687a3222
...
@@ -192,9 +192,9 @@ if(WITH_DISTRIBUTE)
...
@@ -192,9 +192,9 @@ if(WITH_DISTRIBUTE)
set
(
DISTRIBUTE_DEPS
""
)
set
(
DISTRIBUTE_DEPS
""
)
if
(
WITH_GRPC
)
if
(
WITH_GRPC
)
set
(
DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf
)
set
(
DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf
node
)
else
()
else
()
set
(
DISTRIBUTE_DEPS sendrecvop_brpc brpc leveldb snappystream snappy protobuf ssl crypto zlib
)
set
(
DISTRIBUTE_DEPS sendrecvop_brpc brpc leveldb snappystream snappy protobuf ssl crypto zlib
node
)
if
(
WITH_BRPC_RDMA
)
if
(
WITH_BRPC_RDMA
)
find_library
(
IBVERBS_LIBRARY NAMES ibverbs
)
find_library
(
IBVERBS_LIBRARY NAMES ibverbs
)
ADD_LIBRARY
(
ibverbs SHARED IMPORTED GLOBAL
)
ADD_LIBRARY
(
ibverbs SHARED IMPORTED GLOBAL
)
...
...
paddle/fluid/operators/conv_cudnn_op.cu.cc
浏览文件 @
687a3222
...
@@ -77,7 +77,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
...
@@ -77,7 +77,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
// cudnn 7 can support groups, no need to do it mannually
// cudnn 7 can support groups, no need to do it mannually
// FIXME(typhoonzero): find a better way to disable groups
// FIXME(typhoonzero): find a better way to disable groups
// rather than setting it to 1.
// rather than setting it to 1.
PADDLE
_ENFORCE
(
platform
::
dynload
::
cudnnSetConvolutionGroupCount
(
CUDNN
_ENFORCE
(
platform
::
dynload
::
cudnnSetConvolutionGroupCount
(
cudnn_conv_desc
,
groups
));
cudnn_conv_desc
,
groups
));
groups
=
1
;
groups
=
1
;
#endif
#endif
...
@@ -129,7 +129,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
...
@@ -129,7 +129,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
handle
=
dev_ctx
.
cudnn_handle
();
auto
handle
=
dev_ctx
.
cudnn_handle
();
PADDLE
_ENFORCE
(
platform
::
dynload
::
cudnnGetConvolutionForwardAlgorithm
(
CUDNN
_ENFORCE
(
platform
::
dynload
::
cudnnGetConvolutionForwardAlgorithm
(
handle
,
cudnn_input_desc
,
cudnn_filter_desc
,
cudnn_conv_desc
,
handle
,
cudnn_input_desc
,
cudnn_filter_desc
,
cudnn_conv_desc
,
cudnn_output_desc
,
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
,
cudnn_output_desc
,
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
,
workspace_size_limit
,
&
algo
));
workspace_size_limit
,
&
algo
));
...
@@ -140,18 +140,18 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
...
@@ -140,18 +140,18 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
if
(
dev_ctx
.
GetComputeCapability
()
>=
70
&&
if
(
dev_ctx
.
GetComputeCapability
()
>=
70
&&
std
::
type_index
(
typeid
(
T
))
==
std
::
type_index
(
typeid
(
T
))
==
std
::
type_index
(
typeid
(
platform
::
float16
)))
{
std
::
type_index
(
typeid
(
platform
::
float16
)))
{
PADDLE
_ENFORCE
(
platform
::
dynload
::
cudnnSetConvolutionMathType
(
CUDNN
_ENFORCE
(
platform
::
dynload
::
cudnnSetConvolutionMathType
(
cudnn_conv_desc
,
CUDNN_TENSOR_OP_MATH
));
cudnn_conv_desc
,
CUDNN_TENSOR_OP_MATH
));
// Currently tensor core is only enabled using this algo
// Currently tensor core is only enabled using this algo
algo
=
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
;
algo
=
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
;
}
else
{
}
else
{
PADDLE
_ENFORCE
(
platform
::
dynload
::
cudnnSetConvolutionMathType
(
CUDNN
_ENFORCE
(
platform
::
dynload
::
cudnnSetConvolutionMathType
(
cudnn_conv_desc
,
CUDNN_DEFAULT_MATH
));
cudnn_conv_desc
,
CUDNN_DEFAULT_MATH
));
}
}
#endif
#endif
// get workspace size able to allocate
// get workspace size able to allocate
PADDLE
_ENFORCE
(
platform
::
dynload
::
cudnnGetConvolutionForwardWorkspaceSize
(
CUDNN
_ENFORCE
(
platform
::
dynload
::
cudnnGetConvolutionForwardWorkspaceSize
(
handle
,
cudnn_input_desc
,
cudnn_filter_desc
,
cudnn_conv_desc
,
handle
,
cudnn_input_desc
,
cudnn_filter_desc
,
cudnn_conv_desc
,
cudnn_output_desc
,
algo
,
&
workspace_size_in_bytes
));
cudnn_output_desc
,
algo
,
&
workspace_size_in_bytes
));
// It is possible for float16 on Volta GPU to allocate more memory than
// It is possible for float16 on Volta GPU to allocate more memory than
...
@@ -165,7 +165,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
...
@@ -165,7 +165,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv forward ---------------------
// ------------------- cudnn conv forward ---------------------
ScalingParamType
<
T
>
alpha
=
1.0
f
,
beta
=
0.0
f
;
ScalingParamType
<
T
>
alpha
=
1.0
f
,
beta
=
0.0
f
;
for
(
int
i
=
0
;
i
<
groups
;
i
++
)
{
for
(
int
i
=
0
;
i
<
groups
;
i
++
)
{
PADDLE
_ENFORCE
(
platform
::
dynload
::
cudnnConvolutionForward
(
CUDNN
_ENFORCE
(
platform
::
dynload
::
cudnnConvolutionForward
(
handle
,
&
alpha
,
cudnn_input_desc
,
input_data
+
i
*
group_offset_in
,
handle
,
&
alpha
,
cudnn_input_desc
,
input_data
+
i
*
group_offset_in
,
cudnn_filter_desc
,
filter_data
+
i
*
group_offset_filter
,
cudnn_filter_desc
,
filter_data
+
i
*
group_offset_filter
,
cudnn_conv_desc
,
algo
,
cudnn_workspace
,
workspace_size_in_bytes
,
cudnn_conv_desc
,
algo
,
cudnn_workspace
,
workspace_size_in_bytes
,
...
@@ -218,7 +218,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
...
@@ -218,7 +218,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
// cudnn 7 can support groups, no need to do it mannually
// cudnn 7 can support groups, no need to do it mannually
// FIXME(typhoonzero): find a better way to disable groups
// FIXME(typhoonzero): find a better way to disable groups
// rather than setting it to 1.
// rather than setting it to 1.
PADDLE
_ENFORCE
(
platform
::
dynload
::
cudnnSetConvolutionGroupCount
(
CUDNN
_ENFORCE
(
platform
::
dynload
::
cudnnSetConvolutionGroupCount
(
cudnn_conv_desc
,
groups
));
cudnn_conv_desc
,
groups
));
groups
=
1
;
groups
=
1
;
#endif
#endif
...
@@ -273,7 +273,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
...
@@ -273,7 +273,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
auto
handle
=
dev_ctx
.
cudnn_handle
();
auto
handle
=
dev_ctx
.
cudnn_handle
();
if
(
input_grad
)
{
if
(
input_grad
)
{
if
(
FLAGS_cudnn_deterministic
)
{
if
(
FLAGS_cudnn_deterministic
)
{
PADDLE
_ENFORCE
(
CUDNN
_ENFORCE
(
platform
::
dynload
::
cudnnGetConvolutionBackwardDataAlgorithm
(
platform
::
dynload
::
cudnnGetConvolutionBackwardDataAlgorithm
(
handle
,
cudnn_filter_desc
,
handle
,
cudnn_filter_desc
,
// dyDesc: Handle to the previously initialized input
// dyDesc: Handle to the previously initialized input
...
@@ -289,7 +289,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
...
@@ -289,7 +289,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
data_algo
=
CUDNN_CONVOLUTION_BWD_DATA_ALGO_1
;
data_algo
=
CUDNN_CONVOLUTION_BWD_DATA_ALGO_1
;
}
}
PADDLE
_ENFORCE
(
CUDNN
_ENFORCE
(
platform
::
dynload
::
cudnnGetConvolutionBackwardDataWorkspaceSize
(
platform
::
dynload
::
cudnnGetConvolutionBackwardDataWorkspaceSize
(
handle
,
cudnn_filter_desc
,
cudnn_output_grad_desc
,
handle
,
cudnn_filter_desc
,
cudnn_output_grad_desc
,
cudnn_conv_desc
,
cudnn_input_desc
,
data_algo
,
&
tmp_size
));
cudnn_conv_desc
,
cudnn_input_desc
,
data_algo
,
&
tmp_size
));
...
@@ -298,7 +298,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
...
@@ -298,7 +298,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
if
(
filter_grad
)
{
if
(
filter_grad
)
{
if
(
FLAGS_cudnn_deterministic
)
{
if
(
FLAGS_cudnn_deterministic
)
{
PADDLE
_ENFORCE
(
CUDNN
_ENFORCE
(
platform
::
dynload
::
cudnnGetConvolutionBackwardFilterAlgorithm
(
platform
::
dynload
::
cudnnGetConvolutionBackwardFilterAlgorithm
(
handle
,
cudnn_input_desc
,
cudnn_output_grad_desc
,
handle
,
cudnn_input_desc
,
cudnn_output_grad_desc
,
cudnn_conv_desc
,
cudnn_filter_desc
,
cudnn_conv_desc
,
cudnn_filter_desc
,
...
@@ -308,7 +308,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
...
@@ -308,7 +308,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
filter_algo
=
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1
;
filter_algo
=
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1
;
}
}
PADDLE
_ENFORCE
(
CUDNN
_ENFORCE
(
platform
::
dynload
::
cudnnGetConvolutionBackwardFilterWorkspaceSize
(
platform
::
dynload
::
cudnnGetConvolutionBackwardFilterWorkspaceSize
(
handle
,
cudnn_input_desc
,
cudnn_output_grad_desc
,
cudnn_conv_desc
,
handle
,
cudnn_input_desc
,
cudnn_output_grad_desc
,
cudnn_conv_desc
,
cudnn_filter_desc
,
filter_algo
,
&
tmp_size
));
cudnn_filter_desc
,
filter_algo
,
&
tmp_size
));
...
@@ -326,7 +326,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
...
@@ -326,7 +326,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
// Because beta is zero, it is unnecessary to reset input_grad.
// Because beta is zero, it is unnecessary to reset input_grad.
for
(
int
i
=
0
;
i
<
groups
;
i
++
)
{
for
(
int
i
=
0
;
i
<
groups
;
i
++
)
{
PADDLE
_ENFORCE
(
platform
::
dynload
::
cudnnConvolutionBackwardData
(
CUDNN
_ENFORCE
(
platform
::
dynload
::
cudnnConvolutionBackwardData
(
handle
,
&
alpha
,
cudnn_filter_desc
,
handle
,
&
alpha
,
cudnn_filter_desc
,
filter_data
+
i
*
group_offset_filter
,
cudnn_output_grad_desc
,
filter_data
+
i
*
group_offset_filter
,
cudnn_output_grad_desc
,
output_grad_data
+
i
*
group_offset_out
,
cudnn_conv_desc
,
data_algo
,
output_grad_data
+
i
*
group_offset_out
,
cudnn_conv_desc
,
data_algo
,
...
@@ -339,7 +339,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
...
@@ -339,7 +339,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
T
*
filter_grad_data
=
filter_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
filter_grad_data
=
filter_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
// Because beta is zero, it is unnecessary to reset filter_grad.
// Because beta is zero, it is unnecessary to reset filter_grad.
for
(
int
i
=
0
;
i
<
groups
;
i
++
)
{
for
(
int
i
=
0
;
i
<
groups
;
i
++
)
{
PADDLE
_ENFORCE
(
platform
::
dynload
::
cudnnConvolutionBackwardFilter
(
CUDNN
_ENFORCE
(
platform
::
dynload
::
cudnnConvolutionBackwardFilter
(
handle
,
&
alpha
,
cudnn_input_desc
,
input_data
+
i
*
group_offset_in
,
handle
,
&
alpha
,
cudnn_input_desc
,
input_data
+
i
*
group_offset_in
,
cudnn_output_grad_desc
,
output_grad_data
+
i
*
group_offset_out
,
cudnn_output_grad_desc
,
output_grad_data
+
i
*
group_offset_out
,
cudnn_conv_desc
,
filter_algo
,
cudnn_workspace
,
cudnn_conv_desc
,
filter_algo
,
cudnn_workspace
,
...
...
paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc
浏览文件 @
687a3222
...
@@ -87,7 +87,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
...
@@ -87,7 +87,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
handle
=
dev_ctx
.
cudnn_handle
();
auto
handle
=
dev_ctx
.
cudnn_handle
();
// Get the algorithm
// Get the algorithm
PADDLE
_ENFORCE
(
platform
::
dynload
::
cudnnGetConvolutionBackwardDataAlgorithm
(
CUDNN
_ENFORCE
(
platform
::
dynload
::
cudnnGetConvolutionBackwardDataAlgorithm
(
handle
,
cudnn_filter_desc
,
cudnn_input_desc
,
cudnn_conv_desc
,
handle
,
cudnn_filter_desc
,
cudnn_input_desc
,
cudnn_conv_desc
,
// dxDesc: Handle to the previously initialized output tensor
// dxDesc: Handle to the previously initialized output tensor
// descriptor.
// descriptor.
...
@@ -95,7 +95,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
...
@@ -95,7 +95,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
workspace_size_limit
,
&
algo
));
workspace_size_limit
,
&
algo
));
// get workspace size able to allocate
// get workspace size able to allocate
PADDLE
_ENFORCE
(
CUDNN
_ENFORCE
(
platform
::
dynload
::
cudnnGetConvolutionBackwardDataWorkspaceSize
(
platform
::
dynload
::
cudnnGetConvolutionBackwardDataWorkspaceSize
(
handle
,
cudnn_filter_desc
,
cudnn_input_desc
,
cudnn_conv_desc
,
handle
,
cudnn_filter_desc
,
cudnn_input_desc
,
cudnn_conv_desc
,
cudnn_output_desc
,
algo
,
&
workspace_size_in_bytes
));
cudnn_output_desc
,
algo
,
&
workspace_size_in_bytes
));
...
@@ -110,7 +110,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
...
@@ -110,7 +110,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
int
filter_offset
=
filter
->
numel
()
/
groups
;
int
filter_offset
=
filter
->
numel
()
/
groups
;
T
alpha
=
1.0
f
,
beta
=
0.0
f
;
T
alpha
=
1.0
f
,
beta
=
0.0
f
;
for
(
int
g
=
0
;
g
<
groups
;
g
++
)
{
for
(
int
g
=
0
;
g
<
groups
;
g
++
)
{
PADDLE
_ENFORCE
(
platform
::
dynload
::
cudnnConvolutionBackwardData
(
CUDNN
_ENFORCE
(
platform
::
dynload
::
cudnnConvolutionBackwardData
(
handle
,
&
alpha
,
cudnn_filter_desc
,
filter_data
+
filter_offset
*
g
,
handle
,
&
alpha
,
cudnn_filter_desc
,
filter_data
+
filter_offset
*
g
,
cudnn_input_desc
,
input_data
+
input_offset
*
g
,
cudnn_conv_desc
,
cudnn_input_desc
,
input_data
+
input_offset
*
g
,
cudnn_conv_desc
,
algo
,
cudnn_workspace
,
workspace_size_in_bytes
,
&
beta
,
algo
,
cudnn_workspace
,
workspace_size_in_bytes
,
&
beta
,
...
@@ -178,11 +178,11 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
...
@@ -178,11 +178,11 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
auto
handle
=
dev_ctx
.
cudnn_handle
();
auto
handle
=
dev_ctx
.
cudnn_handle
();
if
(
input_grad
)
{
if
(
input_grad
)
{
// choose backward algorithm for data
// choose backward algorithm for data
PADDLE
_ENFORCE
(
platform
::
dynload
::
cudnnGetConvolutionForwardAlgorithm
(
CUDNN
_ENFORCE
(
platform
::
dynload
::
cudnnGetConvolutionForwardAlgorithm
(
handle
,
cudnn_output_desc
,
cudnn_filter_desc
,
cudnn_conv_desc
,
handle
,
cudnn_output_desc
,
cudnn_filter_desc
,
cudnn_conv_desc
,
cudnn_input_desc
,
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
,
cudnn_input_desc
,
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
,
workspace_size_limit
,
&
data_algo
));
workspace_size_limit
,
&
data_algo
));
PADDLE
_ENFORCE
(
platform
::
dynload
::
cudnnGetConvolutionForwardWorkspaceSize
(
CUDNN
_ENFORCE
(
platform
::
dynload
::
cudnnGetConvolutionForwardWorkspaceSize
(
handle
,
cudnn_output_desc
,
cudnn_filter_desc
,
cudnn_conv_desc
,
handle
,
cudnn_output_desc
,
cudnn_filter_desc
,
cudnn_conv_desc
,
cudnn_input_desc
,
data_algo
,
&
fwd_ws_size
));
cudnn_input_desc
,
data_algo
,
&
fwd_ws_size
));
workspace_size_in_bytes
=
std
::
max
(
workspace_size_in_bytes
,
fwd_ws_size
);
workspace_size_in_bytes
=
std
::
max
(
workspace_size_in_bytes
,
fwd_ws_size
);
...
@@ -190,7 +190,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
...
@@ -190,7 +190,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
if
(
filter_grad
)
{
if
(
filter_grad
)
{
// choose backward algorithm for filter
// choose backward algorithm for filter
PADDLE
_ENFORCE
(
CUDNN
_ENFORCE
(
platform
::
dynload
::
cudnnGetConvolutionBackwardFilterAlgorithm
(
platform
::
dynload
::
cudnnGetConvolutionBackwardFilterAlgorithm
(
handle
,
cudnn_output_desc
,
cudnn_input_desc
,
cudnn_conv_desc
,
handle
,
cudnn_output_desc
,
cudnn_input_desc
,
cudnn_conv_desc
,
cudnn_filter_desc
,
cudnn_filter_desc
,
...
@@ -198,7 +198,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
...
@@ -198,7 +198,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
workspace_size_limit
,
&
filter_algo
));
workspace_size_limit
,
&
filter_algo
));
// get workspace for backwards filter algorithm
// get workspace for backwards filter algorithm
PADDLE
_ENFORCE
(
CUDNN
_ENFORCE
(
platform
::
dynload
::
cudnnGetConvolutionBackwardFilterWorkspaceSize
(
platform
::
dynload
::
cudnnGetConvolutionBackwardFilterWorkspaceSize
(
handle
,
cudnn_output_desc
,
cudnn_input_desc
,
cudnn_conv_desc
,
handle
,
cudnn_output_desc
,
cudnn_input_desc
,
cudnn_conv_desc
,
cudnn_filter_desc
,
filter_algo
,
&
bwd_filter_ws_size
));
cudnn_filter_desc
,
filter_algo
,
&
bwd_filter_ws_size
));
...
@@ -222,7 +222,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
...
@@ -222,7 +222,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
// Because beta is zero, it is unnecessary to reset input_grad.
// Because beta is zero, it is unnecessary to reset input_grad.
for
(
int
g
=
0
;
g
<
groups
;
g
++
)
{
for
(
int
g
=
0
;
g
<
groups
;
g
++
)
{
PADDLE
_ENFORCE
(
platform
::
dynload
::
cudnnConvolutionForward
(
CUDNN
_ENFORCE
(
platform
::
dynload
::
cudnnConvolutionForward
(
handle
,
&
alpha
,
cudnn_output_desc
,
handle
,
&
alpha
,
cudnn_output_desc
,
output_grad_data
+
output_grad_offset
*
g
,
cudnn_filter_desc
,
output_grad_data
+
output_grad_offset
*
g
,
cudnn_filter_desc
,
filter_data
+
filter_offset
*
g
,
cudnn_conv_desc
,
data_algo
,
filter_data
+
filter_offset
*
g
,
cudnn_conv_desc
,
data_algo
,
...
@@ -237,7 +237,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
...
@@ -237,7 +237,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
// Because beta is zero, it is unnecessary to reset filter_grad.
// Because beta is zero, it is unnecessary to reset filter_grad.
// Gradient with respect to the filter
// Gradient with respect to the filter
for
(
int
g
=
0
;
g
<
groups
;
g
++
)
{
for
(
int
g
=
0
;
g
<
groups
;
g
++
)
{
PADDLE
_ENFORCE
(
platform
::
dynload
::
cudnnConvolutionBackwardFilter
(
CUDNN
_ENFORCE
(
platform
::
dynload
::
cudnnConvolutionBackwardFilter
(
handle
,
&
alpha
,
cudnn_output_desc
,
handle
,
&
alpha
,
cudnn_output_desc
,
output_grad_data
+
output_grad_offset
*
g
,
cudnn_input_desc
,
output_grad_data
+
output_grad_offset
*
g
,
cudnn_input_desc
,
input_data
+
input_offset
*
g
,
cudnn_conv_desc
,
filter_algo
,
input_data
+
input_offset
*
g
,
cudnn_conv_desc
,
filter_algo
,
...
...
paddle/fluid/operators/math/softmax.cu
浏览文件 @
687a3222
...
@@ -52,7 +52,7 @@ void SoftmaxCUDNNFunctor<T>::operator()(
...
@@ -52,7 +52,7 @@ void SoftmaxCUDNNFunctor<T>::operator()(
xDesc
.
descriptor
<
T
>
(
layout
,
cudnn_tensor_dims
);
xDesc
.
descriptor
<
T
>
(
layout
,
cudnn_tensor_dims
);
cudnnTensorDescriptor_t
cudnn_y_desc
=
cudnnTensorDescriptor_t
cudnn_y_desc
=
xDesc
.
descriptor
<
T
>
(
layout
,
cudnn_tensor_dims
);
xDesc
.
descriptor
<
T
>
(
layout
,
cudnn_tensor_dims
);
PADDLE
_ENFORCE
(
platform
::
dynload
::
cudnnSoftmaxForward
(
CUDNN
_ENFORCE
(
platform
::
dynload
::
cudnnSoftmaxForward
(
context
.
cudnn_handle
(),
CUDNN_SOFTMAX_ACCURATE
,
context
.
cudnn_handle
(),
CUDNN_SOFTMAX_ACCURATE
,
CUDNN_SOFTMAX_MODE_INSTANCE
,
CudnnDataType
<
T
>::
kOne
(),
cudnn_x_desc
,
CUDNN_SOFTMAX_MODE_INSTANCE
,
CudnnDataType
<
T
>::
kOne
(),
cudnn_x_desc
,
X
->
data
<
T
>
(),
CudnnDataType
<
T
>::
kZero
(),
cudnn_y_desc
,
X
->
data
<
T
>
(),
CudnnDataType
<
T
>::
kZero
(),
cudnn_y_desc
,
...
@@ -83,7 +83,7 @@ void SoftmaxGradCUDNNFunctor<T>::operator()(
...
@@ -83,7 +83,7 @@ void SoftmaxGradCUDNNFunctor<T>::operator()(
dxDesc
.
descriptor
<
T
>
(
layout
,
cudnn_tensor_dims
);
dxDesc
.
descriptor
<
T
>
(
layout
,
cudnn_tensor_dims
);
cudnnTensorDescriptor_t
cudnn_ygrad_desc
=
cudnnTensorDescriptor_t
cudnn_ygrad_desc
=
dyDesc
.
descriptor
<
T
>
(
layout
,
cudnn_tensor_dims
);
dyDesc
.
descriptor
<
T
>
(
layout
,
cudnn_tensor_dims
);
PADDLE
_ENFORCE
(
platform
::
dynload
::
cudnnSoftmaxBackward
(
CUDNN
_ENFORCE
(
platform
::
dynload
::
cudnnSoftmaxBackward
(
context
.
cudnn_handle
(),
CUDNN_SOFTMAX_ACCURATE
,
context
.
cudnn_handle
(),
CUDNN_SOFTMAX_ACCURATE
,
CUDNN_SOFTMAX_MODE_INSTANCE
,
CudnnDataType
<
T
>::
kOne
(),
cudnn_y_desc
,
CUDNN_SOFTMAX_MODE_INSTANCE
,
CudnnDataType
<
T
>::
kOne
(),
cudnn_y_desc
,
Y
->
data
<
T
>
(),
cudnn_ygrad_desc
,
YGrad
->
data
<
T
>
(),
Y
->
data
<
T
>
(),
cudnn_ygrad_desc
,
YGrad
->
data
<
T
>
(),
...
...
paddle/fluid/operators/pool_cudnn_op.cu.cc
浏览文件 @
687a3222
...
@@ -81,7 +81,7 @@ class PoolCUDNNOpKernel : public framework::OpKernel<T> {
...
@@ -81,7 +81,7 @@ class PoolCUDNNOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn pool algorithm ---------------------
// ------------------- cudnn pool algorithm ---------------------
auto
handle
=
ctx
.
cuda_device_context
().
cudnn_handle
();
auto
handle
=
ctx
.
cuda_device_context
().
cudnn_handle
();
ScalingParamType
<
T
>
alpha
=
1.0
f
,
beta
=
0.0
f
;
ScalingParamType
<
T
>
alpha
=
1.0
f
,
beta
=
0.0
f
;
PADDLE
_ENFORCE
(
platform
::
dynload
::
cudnnPoolingForward
(
CUDNN
_ENFORCE
(
platform
::
dynload
::
cudnnPoolingForward
(
handle
,
cudnn_pool_desc
,
&
alpha
,
cudnn_input_desc
,
input_data
,
&
beta
,
handle
,
cudnn_pool_desc
,
&
alpha
,
cudnn_input_desc
,
input_data
,
&
beta
,
cudnn_output_desc
,
output_data
));
cudnn_output_desc
,
output_data
));
}
}
...
@@ -154,7 +154,7 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel<T> {
...
@@ -154,7 +154,7 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel<T> {
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
// Because beta is zero, it is unnecessary to reset input_grad.
// Because beta is zero, it is unnecessary to reset input_grad.
PADDLE
_ENFORCE
(
platform
::
dynload
::
cudnnPoolingBackward
(
CUDNN
_ENFORCE
(
platform
::
dynload
::
cudnnPoolingBackward
(
handle
,
cudnn_pool_desc
,
&
alpha
,
cudnn_output_desc
,
output_data
,
handle
,
cudnn_pool_desc
,
&
alpha
,
cudnn_output_desc
,
output_data
,
cudnn_output_desc
,
output_grad_data
,
cudnn_input_desc
,
input_data
,
cudnn_output_desc
,
output_grad_data
,
cudnn_input_desc
,
input_data
,
&
beta
,
cudnn_input_desc
,
input_grad_data
));
&
beta
,
cudnn_input_desc
,
input_grad_data
));
...
...
paddle/fluid/operators/send_recv_util.h
浏览文件 @
687a3222
...
@@ -14,6 +14,7 @@ limitations under the License. */
...
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#pragma once
#include <string>
#include <string>
#include "paddle/fluid/framework/ir/node.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -22,7 +23,10 @@ inline bool NeedSend(const framework::Scope& scope,
...
@@ -22,7 +23,10 @@ inline bool NeedSend(const framework::Scope& scope,
const
std
::
string
&
varname
)
{
const
std
::
string
&
varname
)
{
// dummy variable is only used in parallel executor to represent
// dummy variable is only used in parallel executor to represent
// some dependency relationship, we don't need to send/recv it.
// some dependency relationship, we don't need to send/recv it.
if
(
varname
==
"dummy"
)
return
false
;
// TODO(paddle-dev): Why would parallel executor logic leaked into here?
if
(
varname
.
find
(
framework
::
ir
::
Node
::
kControlDepVarName
)
!=
std
::
string
::
npos
)
return
false
;
auto
*
var
=
scope
.
FindVar
(
varname
);
auto
*
var
=
scope
.
FindVar
(
varname
);
PADDLE_ENFORCE_NOT_NULL
(
var
,
"Can not find variable '%s' in the send side."
,
PADDLE_ENFORCE_NOT_NULL
(
var
,
"Can not find variable '%s' in the send side."
,
varname
);
varname
);
...
...
paddle/fluid/platform/cudnn_helper.h
浏览文件 @
687a3222
...
@@ -62,9 +62,8 @@ inline const char* cudnnGetErrorString(cudnnStatus_t status) {
...
@@ -62,9 +62,8 @@ inline const char* cudnnGetErrorString(cudnnStatus_t status) {
#define CUDNN_ENFORCE(condition) \
#define CUDNN_ENFORCE(condition) \
do { \
do { \
cudnnStatus_t status = condition; \
cudnnStatus_t status = condition; \
if (status != CUDNN_STATUS_SUCCESS) { \
if (UNLIKELY(status != CUDNN_STATUS_SUCCESS)) { \
VLOG(1) << ::paddle::platform::cudnnGetErrorString(status); \
PADDLE_THROW(::paddle::platform::cudnnGetErrorString(status)); \
PADDLE_THROW("cuDNN call failed"); \
} \
} \
} while (false)
} while (false)
...
...
paddle/scripts/paddle_build.sh
浏览文件 @
687a3222
...
@@ -547,6 +547,7 @@ function test_fluid_inference_lib() {
...
@@ -547,6 +547,7 @@ function test_fluid_inference_lib() {
EOF
EOF
cd
${
PADDLE_ROOT
}
/paddle/fluid/inference/api/demo_ci
cd
${
PADDLE_ROOT
}
/paddle/fluid/inference/api/demo_ci
./run.sh
${
PADDLE_ROOT
}
${
WITH_MKL
:-
ON
}
${
WITH_GPU
:-
OFF
}
./run.sh
${
PADDLE_ROOT
}
${
WITH_MKL
:-
ON
}
${
WITH_GPU
:-
OFF
}
./clean.sh
fi
fi
}
}
...
...
python/paddle/fluid/__init__.py
浏览文件 @
687a3222
...
@@ -123,7 +123,7 @@ def __bootstrap__():
...
@@ -123,7 +123,7 @@ def __bootstrap__():
read_env_flags
=
[
read_env_flags
=
[
'use_pinned_memory'
,
'check_nan_inf'
,
'benchmark'
,
'warpctc_dir'
,
'use_pinned_memory'
,
'check_nan_inf'
,
'benchmark'
,
'warpctc_dir'
,
'eager_delete_scope'
,
'use_mkldnn'
,
'initial_cpu_memory_in_mb'
,
'eager_delete_scope'
,
'use_mkldnn'
,
'initial_cpu_memory_in_mb'
,
'init_allocated_mem'
'init_allocated_mem'
,
'free_idle_memory'
]
]
if
core
.
is_compiled_with_dist
():
if
core
.
is_compiled_with_dist
():
read_env_flags
.
append
(
'rpc_deadline'
)
read_env_flags
.
append
(
'rpc_deadline'
)
...
...
python/paddle/fluid/framework.py
浏览文件 @
687a3222
...
@@ -1540,7 +1540,12 @@ class Program(object):
...
@@ -1540,7 +1540,12 @@ class Program(object):
def
inference_optimize
(
self
):
def
inference_optimize
(
self
):
"""
"""
This method will create a new program and change the :code:`is_test`
This method will create a new program and do following adjustments on it:
1. Remove all reader variables and their creator ops if exist.
2. Remove the :code:`read_op` if exists.
3. change the :code:`is_test`
attribute of operators to :code:`True`. All the :code:`Parameter`
attribute of operators to :code:`True`. All the :code:`Parameter`
information will be lost.
information will be lost.
...
@@ -1554,6 +1559,22 @@ class Program(object):
...
@@ -1554,6 +1559,22 @@ class Program(object):
# core.inference_optimize being fixed.
# core.inference_optimize being fixed.
res
=
Program
()
res
=
Program
()
res
.
desc
=
core
.
ProgramDesc
(
self
.
desc
)
res
.
desc
=
core
.
ProgramDesc
(
self
.
desc
)
# remove all readers and the read_op if exist
read_op_idx
=
0
root_block
=
res
.
desc
.
block
(
0
)
while
True
:
if
read_op_idx
>=
root_block
.
op_size
()
or
root_block
.
op
(
read_op_idx
).
type
()
==
'read'
:
break
read_op_idx
+=
1
if
read_op_idx
<
root_block
.
op_size
():
root_block
.
_remove_op
(
0
,
read_op_idx
+
1
)
for
var
in
root_block
.
all_vars
():
if
var
.
type
()
==
core
.
VarDesc
.
VarType
.
READER
:
root_block
.
_remove_var
(
var
.
name
())
# change all `is_test` attributes to True
for
i
in
xrange
(
res
.
desc
.
num_blocks
()):
for
i
in
xrange
(
res
.
desc
.
num_blocks
()):
block
=
res
.
desc
.
block
(
i
)
block
=
res
.
desc
.
block
(
i
)
for
j
in
xrange
(
block
.
op_size
()):
for
j
in
xrange
(
block
.
op_size
()):
...
...
python/paddle/fluid/layers/io.py
浏览文件 @
687a3222
...
@@ -443,9 +443,6 @@ def random_data_generator(low, high, shapes, lod_levels, for_parallel=True):
...
@@ -443,9 +443,6 @@ def random_data_generator(low, high, shapes, lod_levels, for_parallel=True):
main_prog_var
=
_copy_reader_var_
(
default_main_program
().
current_block
(),
main_prog_var
=
_copy_reader_var_
(
default_main_program
().
current_block
(),
startup_var
)
startup_var
)
if
for_parallel
:
main_prog_var
=
parallel
(
reader
=
main_prog_var
)
return
monkey_patch_reader_methods
(
main_prog_var
)
return
monkey_patch_reader_methods
(
main_prog_var
)
...
...
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
687a3222
...
@@ -779,7 +779,9 @@ class DistributeTranspiler(object):
...
@@ -779,7 +779,9 @@ class DistributeTranspiler(object):
outputs
=
{
"Out"
:
prefetch_output_vars
},
outputs
=
{
"Out"
:
prefetch_output_vars
},
attrs
=
{
attrs
=
{
"epmap"
:
pserver_endpoints
,
"epmap"
:
pserver_endpoints
,
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
# FIXME(qiao) temporarily disable this config because prefetch
# is not act as other rpc op, it's more like a forward op
# RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
})
# insert concat_op
# insert concat_op
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录