Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
a5c96af3
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a5c96af3
编写于
7月 31, 2018
作者:
N
nhzlx
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into add_tensorrt_conv2d_converter
上级
f05c7fb8
baff71d5
变更
73
显示空白变更内容
内联
并排
Showing
73 changed file
with
1849 addition
and
569 deletion
+1849
-569
doc/fluid/design/ir/draft.md
doc/fluid/design/ir/draft.md
+97
-1
paddle/fluid/API.spec
paddle/fluid/API.spec
+1
-4
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+1
-1
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+0
-3
paddle/fluid/framework/details/exception_holder.h
paddle/fluid/framework/details/exception_holder.h
+83
-0
paddle/fluid/framework/details/multi_devices_graph_builder.cc
...le/fluid/framework/details/multi_devices_graph_builder.cc
+81
-69
paddle/fluid/framework/details/multi_devices_graph_builder.h
paddle/fluid/framework/details/multi_devices_graph_builder.h
+12
-25
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h
...uid/framework/details/scope_buffered_ssa_graph_executor.h
+5
-0
paddle/fluid/framework/details/ssa_graph_builder.cc
paddle/fluid/framework/details/ssa_graph_builder.cc
+7
-6
paddle/fluid/framework/details/ssa_graph_builder.h
paddle/fluid/framework/details/ssa_graph_builder.h
+6
-2
paddle/fluid/framework/details/ssa_graph_checker.cc
paddle/fluid/framework/details/ssa_graph_checker.cc
+10
-3
paddle/fluid/framework/details/ssa_graph_checker.h
paddle/fluid/framework/details/ssa_graph_checker.h
+4
-16
paddle/fluid/framework/details/ssa_graph_executor.h
paddle/fluid/framework/details/ssa_graph_executor.h
+3
-1
paddle/fluid/framework/details/ssa_graph_printer.cc
paddle/fluid/framework/details/ssa_graph_printer.cc
+6
-3
paddle/fluid/framework/details/ssa_graph_printer.h
paddle/fluid/framework/details/ssa_graph_printer.h
+9
-30
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
...le/fluid/framework/details/threaded_ssa_graph_executor.cc
+9
-26
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
+3
-2
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+6
-3
paddle/fluid/framework/ir/graph.h
paddle/fluid/framework/ir/graph.h
+8
-1
paddle/fluid/framework/ir/graph_viz_pass.cc
paddle/fluid/framework/ir/graph_viz_pass.cc
+72
-0
paddle/fluid/framework/ir/graph_viz_pass.h
paddle/fluid/framework/ir/graph_viz_pass.h
+38
-0
paddle/fluid/framework/ir/pass.cc
paddle/fluid/framework/ir/pass.cc
+28
-1
paddle/fluid/framework/ir/pass.h
paddle/fluid/framework/ir/pass.h
+168
-2
paddle/fluid/framework/ir/pass_test.cc
paddle/fluid/framework/ir/pass_test.cc
+112
-0
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+88
-17
paddle/fluid/framework/parallel_executor.h
paddle/fluid/framework/parallel_executor.h
+0
-1
paddle/fluid/inference/analysis/CMakeLists.txt
paddle/fluid/inference/analysis/CMakeLists.txt
+3
-0
paddle/fluid/inference/analysis/analyzer.cc
paddle/fluid/inference/analysis/analyzer.cc
+7
-0
paddle/fluid/inference/analysis/analyzer.h
paddle/fluid/inference/analysis/analyzer.h
+13
-17
paddle/fluid/inference/analysis/analyzer_main.cc
paddle/fluid/inference/analysis/analyzer_main.cc
+33
-0
paddle/fluid/inference/analysis/analyzer_tester.cc
paddle/fluid/inference/analysis/analyzer_tester.cc
+6
-2
paddle/fluid/inference/analysis/argument.h
paddle/fluid/inference/analysis/argument.h
+13
-0
paddle/fluid/inference/analysis/data_flow_graph.h
paddle/fluid/inference/analysis/data_flow_graph.h
+2
-0
paddle/fluid/inference/analysis/data_flow_graph_tester.cc
paddle/fluid/inference/analysis/data_flow_graph_tester.cc
+2
-2
paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass_tester.cc
...nference/analysis/data_flow_graph_to_fluid_pass_tester.cc
+5
-5
paddle/fluid/inference/analysis/dfg_graphviz_draw_pass_tester.cc
...fluid/inference/analysis/dfg_graphviz_draw_pass_tester.cc
+9
-3
paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc
...fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc
+30
-6
paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h
.../fluid/inference/analysis/fluid_to_data_flow_graph_pass.h
+1
-1
paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass_tester.cc
...nference/analysis/fluid_to_data_flow_graph_pass_tester.cc
+2
-1
paddle/fluid/inference/analysis/helper.h
paddle/fluid/inference/analysis/helper.h
+15
-0
paddle/fluid/inference/analysis/model_store_pass.cc
paddle/fluid/inference/analysis/model_store_pass.cc
+61
-0
paddle/fluid/inference/analysis/model_store_pass.h
paddle/fluid/inference/analysis/model_store_pass.h
+51
-0
paddle/fluid/inference/analysis/model_store_pass_tester.cc
paddle/fluid/inference/analysis/model_store_pass_tester.cc
+43
-0
paddle/fluid/inference/analysis/pass.h
paddle/fluid/inference/analysis/pass.h
+1
-0
paddle/fluid/inference/analysis/pass_manager_tester.cc
paddle/fluid/inference/analysis/pass_manager_tester.cc
+5
-2
paddle/fluid/inference/analysis/subgraph_splitter_tester.cc
paddle/fluid/inference/analysis/subgraph_splitter_tester.cc
+4
-4
paddle/fluid/inference/analysis/tensorrt_subgraph_node_mark_pass_tester.cc
...rence/analysis/tensorrt_subgraph_node_mark_pass_tester.cc
+3
-3
paddle/fluid/inference/analysis/tensorrt_subgraph_pass_tester.cc
...fluid/inference/analysis/tensorrt_subgraph_pass_tester.cc
+3
-4
paddle/fluid/inference/analysis/ut_helper.h
paddle/fluid/inference/analysis/ut_helper.h
+3
-18
paddle/fluid/inference/api/api_anakin_engine_tester.cc
paddle/fluid/inference/api/api_anakin_engine_tester.cc
+10
-8
paddle/fluid/inference/api/api_impl.cc
paddle/fluid/inference/api/api_impl.cc
+11
-0
paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc
paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc
+12
-0
paddle/fluid/inference/api/api_tensorrt_subgraph_engine_tester.cc
...luid/inference/api/api_tensorrt_subgraph_engine_tester.cc
+4
-5
paddle/fluid/inference/api/demo_ci/simple_on_word2vec.cc
paddle/fluid/inference/api/demo_ci/simple_on_word2vec.cc
+9
-8
paddle/fluid/inference/api/demo_ci/vis_demo.cc
paddle/fluid/inference/api/demo_ci/vis_demo.cc
+5
-5
paddle/fluid/inference/api/paddle_inference_api.h
paddle/fluid/inference/api/paddle_inference_api.h
+1
-1
paddle/fluid/operators/listen_and_serv_op.cc
paddle/fluid/operators/listen_and_serv_op.cc
+25
-0
paddle/fluid/operators/math/im2col.cc
paddle/fluid/operators/math/im2col.cc
+10
-52
paddle/fluid/operators/math/im2col_cfo_cpu.h
paddle/fluid/operators/math/im2col_cfo_cpu.h
+252
-0
paddle/fluid/operators/math/im2col_test.cc
paddle/fluid/operators/math/im2col_test.cc
+103
-72
paddle/fluid/operators/reshape_op.cc
paddle/fluid/operators/reshape_op.cc
+5
-26
paddle/fluid/operators/split_ids_op.h
paddle/fluid/operators/split_ids_op.h
+11
-4
paddle/fluid/platform/cuda_helper_test.cu
paddle/fluid/platform/cuda_helper_test.cu
+109
-74
paddle/fluid/platform/cuda_primitives.h
paddle/fluid/platform/cuda_primitives.h
+10
-10
paddle/scripts/paddle_build.sh
paddle/scripts/paddle_build.sh
+1
-1
python/paddle/fluid/__init__.py
python/paddle/fluid/__init__.py
+1
-0
python/paddle/fluid/layers/control_flow.py
python/paddle/fluid/layers/control_flow.py
+13
-9
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+4
-5
python/paddle/fluid/tests/unittests/test_memory_optimization_transpiler.py
...id/tests/unittests/test_memory_optimization_transpiler.py
+24
-0
python/paddle/fluid/tests/unittests/test_reshape_op.py
python/paddle/fluid/tests/unittests/test_reshape_op.py
+3
-3
python/paddle/fluid/tests/unittests/test_split_ids_op.py
python/paddle/fluid/tests/unittests/test_split_ids_op.py
+52
-0
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+1
-0
tools/manylinux1/Dockerfile.x64
tools/manylinux1/Dockerfile.x64
+1
-1
未找到文件。
doc/fluid/design/ir/draft.md
浏览文件 @
a5c96af3
...
@@ -64,6 +64,41 @@ can also contain other things that describe some properties of
...
@@ -64,6 +64,41 @@ can also contain other things that describe some properties of
the
`Graph`
or
`Graph`
nodes.
`Attribute`
can be passed
the
`Graph`
or
`Graph`
nodes.
`Attribute`
can be passed
across
`Pass`
. However, it should be used with care.
across
`Pass`
. However, it should be used with care.
```
cpp
class
Graph
{
public:
explicit
Graph
(
const
ProgramDesc
&
program
);
bool
Has
(
const
std
::
string
&
attr_name
)
const
;
template
<
typename
AttrType
>
AttrType
&
Get
(
const
std
::
string
&
attr_name
)
const
;
template
<
typename
AttrType
>
void
Set
(
const
std
::
string
&
attr_name
,
AttrType
*
attr
);
const
std
::
unordered_set
<
ir
::
Node
*>
&
Nodes
()
const
;
// Create a normal variable with non-null VarDesc.
ir
::
Node
*
CreateVarNode
(
VarDesc
*
var_desc
);
// Create a normal runnable operator with OpDesc.
ir
::
Node
*
CreateOpNode
(
OpDesc
*
op_desc
);
// Create a control dependency var that connects 2 operations. The
// var doesn't hold any data. Other than that, it's no different from
// other var, considering dependency analysis.
ir
::
Node
*
CreateControlDepVar
();
// A more free-style way of creating a graph node. Mostly used for tests
// or "copy" from another node. Avoid using it if possible.
ir
::
Node
*
CreateEmptyNode
(
const
std
::
string
&
name
,
ir
::
Node
::
Type
type
);
// Clear all node information of the graph and return the ownership of the
// nodes.
std
::
vector
<
std
::
unique_ptr
<
ir
::
Node
>>
ReleaseNodes
();
};
```
#### Pass
#### Pass
`Pass`
represents a transformation of
`Graph`
. Its input
`Pass`
represents a transformation of
`Graph`
. Its input
...
@@ -71,6 +106,54 @@ is a `Graph` and its output is also a `Graph`. For example,
...
@@ -71,6 +106,54 @@ is a `Graph` and its output is also a `Graph`. For example,
a
`Pass`
can simply print out the
`Graph`
. A
`Pass`
a
`Pass`
can simply print out the
`Graph`
. A
`Pass`
can also fuse some
`Graph`
's
`Node`
s.
can also fuse some
`Graph`
's
`Node`
s.
```
cpp
class
Pass
{
public:
std
::
unique_ptr
<
Graph
>
Apply
(
std
::
unique_ptr
<
Graph
>
graph
)
const
{
// Some correctness check.
auto
new_graph
=
ApplyImpl
(
std
::
move
(
graph
));
// Some correctness check.
return
new_graph
;
}
// Get a reference to the attribute previously set.
template
<
typename
AttrType
>
AttrType
&
Get
(
const
std
::
string
&
attr_name
)
const
;
// Set a pointer to the attribute. Pass takes ownership of the attribute.
template
<
typename
AttrType
>
void
Set
(
const
std
::
string
&
attr_name
,
AttrType
*
attr
)
;
// Set a pointer to the attribute. Pass doesn't take ownership. Caller
// should delete the attribute.
template
<
typename
AttrType
>
void
SetNotOwned
(
const
std
::
string
&
attr_name
,
AttrType
*
attr
);
protected:
virtual
std
::
unique_ptr
<
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
Graph
>
graph
)
const
=
0
;
};
// In my_pass.cc
class
MyPass
:
public
Pass
{
protected:
std
::
unique_ptr
<
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
Graph
>
graph
)
const
override
{
// do something.
return
graph
;
}
}
REGISTER_PASS
(
my_pass
,
MyPass
)
.
RequirePassAttr
(
"places"
)
.
RequireGraphAttr
(
"dep_vars"
);
// To use the pass.
auto
my_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"my_pass"
);
graph
=
my_pass
->
Apply
(
std
::
move
(
graph
));
// Note: to force link my_pass.cc, in the code:
USE_PASS
(
my_pass
);
```
#### Optimize
#### Optimize
`Optimize`
contains a series of
`Pass`
with defined order.
`Optimize`
contains a series of
`Pass`
with defined order.
...
@@ -86,4 +169,17 @@ maintaining the original modeling logic.
...
@@ -86,4 +169,17 @@ maintaining the original modeling logic.
*
Graph is transformed from raw model logic to a
*
Graph is transformed from raw model logic to a
form that is efficient to execute.
form that is efficient to execute.
Program->ProgramToGraph->Graph->Pass1->Graph->Pass2->Graph->Pass3->Graph->Executor
```
// Program->ProgramToGraph->Graph->Pass1->Graph->Pass2->Graph->Pass3->Graph->Executor
auto graph = Graph(program);
graph = PassRegistry::Instance().Get("op_fuse_pass").Apply(std::move(graph));
// For more complex Pass, Optimize Process can provide Pass attributes.
auto mem_opt_pass = PassRegistry::Instance().Get("memory_optimization_pass");
mem_opt_pass.SetNotOwned<int>("optimize_level", 1);
graph = mem_opt_pass->Apply(std::move(graph));
graph = PassRegistry::Instance().Get("multi_device_pass").Apply(std::move(graph));
graph = PassRegistry::Instance().Get("multi_device_check_pass").Apply(std::move(graph));
Executor exe;
exe.Run(graph);
```
paddle/fluid/API.spec
浏览文件 @
a5c96af3
...
@@ -170,6 +170,7 @@ paddle.fluid.layers.mean_iou ArgSpec(args=['input', 'label', 'num_classes'], var
...
@@ -170,6 +170,7 @@ paddle.fluid.layers.mean_iou ArgSpec(args=['input', 'label', 'num_classes'], var
paddle.fluid.layers.relu ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.relu ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.log ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.log ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True))
paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True))
paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
...
@@ -201,7 +202,6 @@ paddle.fluid.layers.zeros ArgSpec(args=['shape', 'dtype', 'force_cpu'], varargs=
...
@@ -201,7 +202,6 @@ paddle.fluid.layers.zeros ArgSpec(args=['shape', 'dtype', 'force_cpu'], varargs=
paddle.fluid.layers.reverse ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.reverse ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.While.__init__ ArgSpec(args=['self', 'cond', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.While.__init__ ArgSpec(args=['self', 'cond', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.While.block ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.While.block ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.While.complete ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.Switch.__init__ ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.Switch.__init__ ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.Switch.case ArgSpec(args=['self', 'condition'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.Switch.case ArgSpec(args=['self', 'condition'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.Switch.default ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.Switch.default ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
...
@@ -225,17 +225,14 @@ paddle.fluid.layers.DynamicRNN.static_input ArgSpec(args=['self', 'x'], varargs=
...
@@ -225,17 +225,14 @@ paddle.fluid.layers.DynamicRNN.static_input ArgSpec(args=['self', 'x'], varargs=
paddle.fluid.layers.DynamicRNN.step_input ArgSpec(args=['self', 'x'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.DynamicRNN.step_input ArgSpec(args=['self', 'x'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.DynamicRNN.update_memory ArgSpec(args=['self', 'ex_mem', 'new_mem'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.DynamicRNN.update_memory ArgSpec(args=['self', 'ex_mem', 'new_mem'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.StaticRNN.__init__ ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.StaticRNN.__init__ ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.StaticRNN.complete_op ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.StaticRNN.memory ArgSpec(args=['self', 'init', 'shape', 'batch_ref', 'init_value', 'init_batch_dim_idx', 'ref_batch_dim_idx'], varargs=None, keywords=None, defaults=(None, None, None, 0.0, 0, 1))
paddle.fluid.layers.StaticRNN.memory ArgSpec(args=['self', 'init', 'shape', 'batch_ref', 'init_value', 'init_batch_dim_idx', 'ref_batch_dim_idx'], varargs=None, keywords=None, defaults=(None, None, None, 0.0, 0, 1))
paddle.fluid.layers.StaticRNN.output ArgSpec(args=['self'], varargs='outputs', keywords=None, defaults=None)
paddle.fluid.layers.StaticRNN.output ArgSpec(args=['self'], varargs='outputs', keywords=None, defaults=None)
paddle.fluid.layers.StaticRNN.parent_block ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.StaticRNN.step ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.StaticRNN.step ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.StaticRNN.step_input ArgSpec(args=['self', 'x'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.StaticRNN.step_input ArgSpec(args=['self', 'x'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.StaticRNN.step_output ArgSpec(args=['self', 'o'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.StaticRNN.step_output ArgSpec(args=['self', 'o'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.StaticRNN.update_memory ArgSpec(args=['self', 'mem', 'var'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.StaticRNN.update_memory ArgSpec(args=['self', 'mem', 'var'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.reorder_lod_tensor_by_rank ArgSpec(args=['x', 'rank_table'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.reorder_lod_tensor_by_rank ArgSpec(args=['x', 'rank_table'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.ParallelDo.__init__ ArgSpec(args=['self', 'places', 'use_nccl', 'name'], varargs=None, keywords=None, defaults=(False, None))
paddle.fluid.layers.ParallelDo.__init__ ArgSpec(args=['self', 'places', 'use_nccl', 'name'], varargs=None, keywords=None, defaults=(False, None))
paddle.fluid.layers.ParallelDo.complete_op ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.ParallelDo.do ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.ParallelDo.do ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.ParallelDo.get_parameters ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.ParallelDo.get_parameters ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.ParallelDo.parent_block ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.ParallelDo.parent_block ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
...
...
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
a5c96af3
...
@@ -99,7 +99,7 @@ else()
...
@@ -99,7 +99,7 @@ else()
endif
()
endif
()
cc_library
(
parallel_executor SRCS parallel_executor.cc DEPS
ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph
)
cc_library
(
parallel_executor SRCS parallel_executor.cc DEPS
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph graph_viz_pass multi_devices_graph_builder ssa_graph_printer ssa_graph_checker
)
cc_library
(
prune SRCS prune.cc DEPS framework_proto
)
cc_library
(
prune SRCS prune.cc DEPS framework_proto
)
cc_test
(
prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context
)
cc_test
(
prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context
)
...
...
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
a5c96af3
...
@@ -31,9 +31,6 @@ cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base s
...
@@ -31,9 +31,6 @@ cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base s
cc_library
(
multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
cc_library
(
multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle
)
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle
)
cc_library
(
ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer ssa_graph_checker
)
cc_library
(
ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto
)
cc_library
(
ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto
)
cc_library
(
threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
cc_library
(
threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
simple_threadpool device_context
)
simple_threadpool device_context
)
...
...
paddle/fluid/framework/details/
ssa_graph_builder_factory
.h
→
paddle/fluid/framework/details/
exception_holder
.h
浏览文件 @
a5c96af3
...
@@ -13,57 +13,69 @@
...
@@ -13,57 +13,69 @@
// limitations under the License.
// limitations under the License.
#pragma once
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
class
Scope
;
namespace
details
{
namespace
details
{
class
SSAGraphBuilderFactory
{
class
ExceptionHolder
{
public:
public:
SSAGraphBuilderFactory
(
const
std
::
vector
<
platform
::
Place
>&
places
,
void
Catch
(
const
platform
::
EnforceNotMet
&
exp
)
{
const
std
::
string
&
loss_var_name
,
std
::
lock_guard
<
std
::
mutex
>
lock
(
mu_
);
const
std
::
unordered_set
<
std
::
string
>&
param_names
,
exception_
.
reset
(
new
platform
::
EnforceNotMet
(
exp
));
const
std
::
vector
<
Scope
*>&
local_scopes
,
type_
=
kEnforceNotMet
;
const
BuildStrategy
&
strategy
)
:
places_
(
places
),
loss_var_name_
(
loss_var_name
),
param_names_
(
param_names
),
local_scopes_
(
local_scopes
),
strategy_
(
strategy
)
{
#ifdef PADDLE_WITH_CUDA
nccl_ctxs_
=
nullptr
;
#endif
}
}
#ifdef PADDLE_WITH_CUDA
void
Catch
(
const
platform
::
EOFException
&
exp
)
{
void
SetNCCLContextMap
(
platform
::
NCCLContextMap
*
nccl_ctxs
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mu_
);
nccl_ctxs_
=
nccl_ctxs
;
// EOFException will not cover up existing EnforceNotMet.
if
(
exception_
.
get
()
==
nullptr
)
{
exception_
.
reset
(
new
platform
::
EOFException
(
exp
));
type_
=
kEOF
;
}
}
bool
ExceptionCatched
()
const
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mu_
);
return
exception_
.
get
()
!=
nullptr
;
}
}
#endif
std
::
unique_ptr
<
SSAGraphBuilder
>
Create
();
void
Throw
()
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mu_
);
switch
(
type_
)
{
case
kNone
:
break
;
case
kEnforceNotMet
:
{
auto
e
=
*
static_cast
<
platform
::
EnforceNotMet
*>
(
exception_
.
get
());
throw
e
;
break
;
}
case
kEOF
:
{
auto
e
=
*
static_cast
<
platform
::
EOFException
*>
(
exception_
.
get
());
throw
e
;
break
;
}
default:
LOG
(
FATAL
)
<<
"Unknown exception."
;
}
exception_
.
reset
();
type_
=
kNone
;
}
void
Clear
()
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mu_
);
exception_
.
reset
();
type_
=
kNone
;
}
private:
private:
std
::
vector
<
platform
::
Place
>
places_
;
enum
ExceptionType
{
kNone
,
kEnforceNotMet
,
kEOF
};
std
::
string
loss_var_name_
;
ExceptionType
type_
{
kNone
};
std
::
unordered_set
<
std
::
string
>
param_names_
;
std
::
vector
<
Scope
*>
local_scopes_
;
BuildStrategy
strategy_
;
#ifdef PADDLE_WITH_CUDA
std
::
unique_ptr
<
std
::
exception
>
exception_
;
platform
::
NCCLContextMap
*
nccl_ctxs_
;
mutable
std
::
mutex
mu_
;
#endif
};
};
}
// namespace details
}
// namespace details
...
...
paddle/fluid/framework/details/multi_devices_graph_builder.cc
浏览文件 @
a5c96af3
...
@@ -34,30 +34,22 @@ namespace paddle {
...
@@ -34,30 +34,22 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
static
const
char
kLossVarName
[]
=
"loss_var_name"
;
static
const
char
kPlaces
[]
=
"places"
;
static
const
char
kParams
[]
=
"params"
;
static
const
char
kLocalScopes
[]
=
"local_scopes"
;
static
const
char
kStrategy
[]
=
"strategy"
;
void
MultiDevSSAGraphBuilder
::
Init
()
const
{
loss_var_name_
=
Get
<
const
std
::
string
>
(
kLossVarName
);
places_
=
Get
<
const
std
::
vector
<
platform
::
Place
>>
(
kPlaces
);
local_scopes_
=
Get
<
const
std
::
vector
<
Scope
*>>
(
kLocalScopes
);
strategy_
=
Get
<
const
BuildStrategy
>
(
kStrategy
);
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
MultiDevSSAGraphBuilder
::
MultiDevSSAGraphBuilder
(
nccl_ctxs_
=
&
Get
<
platform
::
NCCLContextMap
>
(
"nccl_ctxs"
);
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>
&
params
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
platform
::
NCCLContextMap
*
nccl_ctxs
,
const
BuildStrategy
&
strategy
)
:
loss_var_name_
(
loss_var_name
),
places_
(
places
),
local_scopes_
(
local_scopes
),
nccl_ctxs_
(
nccl_ctxs
),
strategy_
(
strategy
)
{
#else
MultiDevSSAGraphBuilder
::
MultiDevSSAGraphBuilder
(
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>
&
params
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
BuildStrategy
&
strategy
)
:
loss_var_name_
(
loss_var_name
),
places_
(
places
),
local_scopes_
(
local_scopes
),
strategy_
(
strategy
)
{
#endif
#endif
for
(
auto
&
p
:
params
)
{
for
(
auto
&
p
:
Get
<
const
std
::
unordered_set
<
std
::
string
>>
(
kParams
))
{
grad_names_
.
insert
(
GradVarName
(
p
));
grad_names_
.
insert
(
GradVarName
(
p
));
}
}
balance_vars_
.
resize
(
places_
.
size
(),
0
);
balance_vars_
.
resize
(
places_
.
size
(),
0
);
...
@@ -72,7 +64,7 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
...
@@ -72,7 +64,7 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
ir
::
Node
*
node
,
ir
::
Node
*
node
,
size_t
place_id
)
const
{
size_t
place_id
)
const
{
auto
p
=
places_
[
place_id
];
auto
p
=
places_
[
place_id
];
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
();
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
().
get
();
op_handle
->
SetDeviceContext
(
p
,
op_handle
->
SetDeviceContext
(
p
,
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
));
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
));
...
@@ -239,8 +231,9 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
...
@@ -239,8 +231,9 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
return
sorted_ret
;
return
sorted_ret
;
}
}
std
::
unique_ptr
<
ir
::
Graph
>
MultiDevSSAGraphBuilder
::
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
MultiDevSSAGraphBuilder
::
Apply
Impl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
Init
();
// Give the topology sort order and rebuild the graph structure.
// Give the topology sort order and rebuild the graph structure.
std
::
vector
<
ir
::
Node
*>
sorted_ops
=
SortOpsAndDelayOptimizeOp
(
*
graph
);
std
::
vector
<
ir
::
Node
*>
sorted_ops
=
SortOpsAndDelayOptimizeOp
(
*
graph
);
auto
nodes
=
graph
->
ReleaseNodes
();
auto
nodes
=
graph
->
ReleaseNodes
();
...
@@ -254,9 +247,10 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
...
@@ -254,9 +247,10 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
std
::
unordered_set
<
std
::
string
>
og_has_been_broadcast
;
std
::
unordered_set
<
std
::
string
>
og_has_been_broadcast
;
// We cannot invoke resize. It is a bug of GCC 4.8
// We cannot invoke resize. It is a bug of GCC 4.8
result
.
Set
(
"vars"
,
new
GraphVars
(
places_
.
size
()));
result
.
Set
(
kGraphVars
,
new
GraphVars
(
places_
.
size
()));
result
.
Set
(
"dep_vars"
,
new
GraphDepVars
);
result
.
Set
(
kGraphDepVars
,
new
GraphDepVars
);
result
.
Set
(
"ops"
,
new
GraphOps
);
result
.
Set
(
kGraphOps
,
new
GraphOps
);
result
.
Set
(
kShardedVarDevice
,
new
ShardedVarDevice
);
// find send/recv vars so that we can place the distributed training
// find send/recv vars so that we can place the distributed training
// related op in the place 0
// related op in the place 0
...
@@ -289,11 +283,12 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
...
@@ -289,11 +283,12 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
// the block.
// the block.
is_forwarding
=
false
;
is_forwarding
=
false
;
}
else
{
}
else
{
int
op_dev_id
=
GetOpDeviceID
(
node
);
int
op_dev_id
=
GetOpDeviceID
(
result
,
node
);
if
(
op_dev_id
!=
-
1
)
{
// This op only runs on one specific device.
if
(
op_dev_id
!=
-
1
)
{
// This op only runs on one specific device.
CreateComputationalOp
(
&
result
,
node
,
op_dev_id
);
CreateComputationalOp
(
&
result
,
node
,
op_dev_id
);
for
(
ir
::
Node
*
n
:
node
->
outputs
)
{
for
(
ir
::
Node
*
n
:
node
->
outputs
)
{
var_name_on_devices_
.
emplace
(
n
->
Name
(),
op_dev_id
);
graph
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
.
emplace
(
n
->
Name
(),
op_dev_id
);
}
}
}
else
{
}
else
{
// This op runs on all devices, and its output may have parameter's
// This op runs on all devices, and its output may have parameter's
...
@@ -330,7 +325,8 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
...
@@ -330,7 +325,8 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
case
BuildStrategy
::
ReduceStrategy
::
kReduce
:
case
BuildStrategy
::
ReduceStrategy
::
kReduce
:
cur_device_id
=
GetAppropriateDeviceID
({
g_name
});
cur_device_id
=
GetAppropriateDeviceID
({
g_name
});
CreateReduceOp
(
&
result
,
g_name
,
cur_device_id
);
CreateReduceOp
(
&
result
,
g_name
,
cur_device_id
);
var_name_on_devices_
.
emplace
(
g_name
,
cur_device_id
);
graph
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
.
emplace
(
g_name
,
cur_device_id
);
bcast_var_name_set
[
cur_device_id
].
emplace
(
p_name
);
bcast_var_name_set
[
cur_device_id
].
emplace
(
p_name
);
break
;
break
;
case
BuildStrategy
::
ReduceStrategy
::
kAllReduce
:
case
BuildStrategy
::
ReduceStrategy
::
kAllReduce
:
...
@@ -416,16 +412,16 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
...
@@ -416,16 +412,16 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
result
->
CreateEmptyNode
(
"broadcast"
,
ir
::
Node
::
Type
::
kOperation
),
result
->
CreateEmptyNode
(
"broadcast"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
);
local_scopes_
,
places_
);
#endif
#endif
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
op_handle
);
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
op_handle
);
auto
*
in
=
auto
*
in
=
result
->
Get
<
GraphVars
>
(
"vars"
).
at
(
src_dev_id
).
at
(
p_name
).
back
().
get
();
result
->
Get
<
GraphVars
>
(
kGraphVars
).
at
(
src_dev_id
).
at
(
p_name
).
back
().
get
();
op_handle
->
AddInput
(
in
);
op_handle
->
AddInput
(
in
);
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
&
p
=
places_
[
i
];
auto
&
p
=
places_
[
i
];
SetCommunicationContext
(
op_handle
,
p
);
SetCommunicationContext
(
op_handle
,
p
);
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
"vars"
).
at
(
i
).
at
(
p_name
);
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
kGraphVars
).
at
(
i
).
at
(
p_name
);
auto
*
out_var
=
new
VarHandle
(
auto
*
out_var
=
new
VarHandle
(
result
->
CreateEmptyNode
(
p_name
,
ir
::
Node
::
Type
::
kVariable
),
vars
.
size
(),
result
->
CreateEmptyNode
(
p_name
,
ir
::
Node
::
Type
::
kVariable
),
vars
.
size
(),
i
,
p_name
,
p
);
i
,
p_name
,
p
);
...
@@ -437,7 +433,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
...
@@ -437,7 +433,7 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
void
MultiDevSSAGraphBuilder
::
CreateComputationalOp
(
ir
::
Graph
*
result
,
void
MultiDevSSAGraphBuilder
::
CreateComputationalOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
ir
::
Node
*
node
,
int
dev_id
)
const
{
int
dev_id
)
const
{
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
ComputationOpHandle
(
result
->
CreateOpNode
(
node
->
Op
()),
new
ComputationOpHandle
(
result
->
CreateOpNode
(
node
->
Op
()),
local_scopes_
[
dev_id
],
places_
[
dev_id
]));
local_scopes_
[
dev_id
],
places_
[
dev_id
]));
CreateOpHandleIOs
(
result
,
node
,
dev_id
);
CreateOpHandleIOs
(
result
,
node
,
dev_id
);
...
@@ -446,20 +442,20 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
...
@@ -446,20 +442,20 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
void
MultiDevSSAGraphBuilder
::
InsertAllReduceOp
(
ir
::
Graph
*
result
,
void
MultiDevSSAGraphBuilder
::
InsertAllReduceOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
og
)
const
{
const
std
::
string
&
og
)
const
{
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
AllReduceOpHandle
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
AllReduceOpHandle
(
result
->
CreateEmptyNode
(
"allreduce"
,
ir
::
Node
::
Type
::
kOperation
),
result
->
CreateEmptyNode
(
"allreduce"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
,
nccl_ctxs_
));
local_scopes_
,
places_
,
nccl_ctxs_
));
#else
#else
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
AllReduceOpHandle
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
AllReduceOpHandle
(
result
->
CreateEmptyNode
(
"allreduce"
,
ir
::
Node
::
Type
::
kOperation
),
result
->
CreateEmptyNode
(
"allreduce"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
));
local_scopes_
,
places_
));
#endif
#endif
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
();
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
().
get
();
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
&
p
=
places_
[
i
];
auto
&
p
=
places_
[
i
];
SetCommunicationContext
(
op_handle
,
p
);
SetCommunicationContext
(
op_handle
,
p
);
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
"vars"
)[
i
][
og
];
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
kGraphVars
)[
i
][
og
];
PADDLE_ENFORCE
(
!
vars
.
empty
());
PADDLE_ENFORCE
(
!
vars
.
empty
());
auto
&
prev_grad
=
vars
.
back
();
auto
&
prev_grad
=
vars
.
back
();
op_handle
->
AddInput
(
prev_grad
.
get
());
op_handle
->
AddInput
(
prev_grad
.
get
());
...
@@ -475,20 +471,20 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
...
@@ -475,20 +471,20 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result,
void
MultiDevSSAGraphBuilder
::
InsertDataBalanceOp
(
void
MultiDevSSAGraphBuilder
::
InsertDataBalanceOp
(
ir
::
Graph
*
result
,
const
std
::
vector
<
std
::
string
>
&
datas
)
const
{
ir
::
Graph
*
result
,
const
std
::
vector
<
std
::
string
>
&
datas
)
const
{
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
DataBalanceOpHandle
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
DataBalanceOpHandle
(
result
->
CreateEmptyNode
(
"data_balance"
,
ir
::
Node
::
Type
::
kOperation
),
result
->
CreateEmptyNode
(
"data_balance"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
,
nccl_ctxs_
));
local_scopes_
,
places_
,
nccl_ctxs_
));
#else
#else
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
DataBalanceOpHandle
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
DataBalanceOpHandle
(
result
->
CreateEmptyNode
(
"data_balance"
,
ir
::
Node
::
Type
::
kOperation
),
result
->
CreateEmptyNode
(
"data_balance"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
));
local_scopes_
,
places_
));
#endif
#endif
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
();
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
().
get
();
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
&
p
=
places_
[
i
];
auto
&
p
=
places_
[
i
];
SetCommunicationContext
(
op_handle
,
p
);
SetCommunicationContext
(
op_handle
,
p
);
for
(
const
std
::
string
&
d_name
:
datas
)
{
for
(
const
std
::
string
&
d_name
:
datas
)
{
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
"vars"
)[
i
][
d_name
];
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
kGraphVars
)[
i
][
d_name
];
PADDLE_ENFORCE
(
!
vars
.
empty
());
PADDLE_ENFORCE
(
!
vars
.
empty
());
op_handle
->
AddInput
(
vars
.
back
().
get
());
op_handle
->
AddInput
(
vars
.
back
().
get
());
auto
var
=
new
VarHandle
(
auto
var
=
new
VarHandle
(
...
@@ -512,7 +508,8 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
...
@@ -512,7 +508,8 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
return
is_pg_once
;
return
is_pg_once
;
}
}
int
MultiDevSSAGraphBuilder
::
GetOpDeviceID
(
ir
::
Node
*
node
)
const
{
int
MultiDevSSAGraphBuilder
::
GetOpDeviceID
(
const
ir
::
Graph
&
graph
,
ir
::
Node
*
node
)
const
{
if
(
strategy_
.
reduce_
!=
BuildStrategy
::
ReduceStrategy
::
kReduce
)
{
if
(
strategy_
.
reduce_
!=
BuildStrategy
::
ReduceStrategy
::
kReduce
)
{
return
-
1
;
return
-
1
;
}
}
...
@@ -525,15 +522,17 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const {
...
@@ -525,15 +522,17 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const {
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
()));
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleVarAttrName
()));
PADDLE_ENFORCE_EQ
(
param_grad
.
size
(),
2U
);
PADDLE_ENFORCE_EQ
(
param_grad
.
size
(),
2U
);
int
dev_id
=
GetVarDeviceID
(
param_grad
[
1
]);
int
dev_id
=
GetVarDeviceID
(
graph
,
param_grad
[
1
]);
PADDLE_ENFORCE_NE
(
dev_id
,
-
1
,
"dev_id should not be -1.[%s, %s, %s]"
,
PADDLE_ENFORCE_NE
(
dev_id
,
-
1
,
"dev_id should not be -1.[%s, %s, %s]"
,
node
->
Op
()
->
Type
(),
param_grad
[
0
],
param_grad
[
1
]);
node
->
Op
()
->
Type
(),
param_grad
[
0
],
param_grad
[
1
]);
return
dev_id
;
return
dev_id
;
}
}
int
MultiDevSSAGraphBuilder
::
GetVarDeviceID
(
const
std
::
string
&
varname
)
const
{
int
MultiDevSSAGraphBuilder
::
GetVarDeviceID
(
const
ir
::
Graph
&
graph
,
auto
got
=
var_name_on_devices_
.
find
(
varname
);
const
std
::
string
&
varname
)
const
{
return
got
==
var_name_on_devices_
.
end
()
?
-
1
:
got
->
second
;
auto
&
sharded_var_device
=
graph
.
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
);
auto
got
=
sharded_var_device
.
find
(
varname
);
return
got
==
sharded_var_device
.
end
()
?
-
1
:
got
->
second
;
}
}
void
MultiDevSSAGraphBuilder
::
CreateScaleLossGradOp
(
ir
::
Graph
*
result
)
const
{
void
MultiDevSSAGraphBuilder
::
CreateScaleLossGradOp
(
ir
::
Graph
*
result
)
const
{
...
@@ -551,7 +550,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const {
...
@@ -551,7 +550,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(ir::Graph *result) const {
result
->
CreateEmptyNode
(
"scale_loss_grad"
,
ir
::
Node
::
Type
::
kOperation
),
result
->
CreateEmptyNode
(
"scale_loss_grad"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
.
size
(),
local_scopes_
[
i
],
places_
[
i
],
local_scopes_
.
size
(),
local_scopes_
[
i
],
places_
[
i
],
communication_dev_ctx
);
communication_dev_ctx
);
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
op_handle
);
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
op_handle
);
// FIXME: Currently ScaleLossGradOp only use device_count as scale
// FIXME: Currently ScaleLossGradOp only use device_count as scale
// factor. So it does not depend on any other operators.
// factor. So it does not depend on any other operators.
...
@@ -572,7 +571,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
...
@@ -572,7 +571,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
for
(
size_t
scope_idx
=
0
;
scope_idx
<
num_places
;
++
scope_idx
)
{
for
(
size_t
scope_idx
=
0
;
scope_idx
<
num_places
;
++
scope_idx
)
{
auto
p
=
places_
[
scope_idx
];
auto
p
=
places_
[
scope_idx
];
auto
s
=
local_scopes_
[
scope_idx
];
auto
s
=
local_scopes_
[
scope_idx
];
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
ComputationOpHandle
(
result
->
CreateOpNode
(
node
->
Op
()),
s
,
p
));
new
ComputationOpHandle
(
result
->
CreateOpNode
(
node
->
Op
()),
s
,
p
));
CreateOpHandleIOs
(
result
,
node
,
scope_idx
);
CreateOpHandleIOs
(
result
,
node
,
scope_idx
);
}
}
...
@@ -582,25 +581,25 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
...
@@ -582,25 +581,25 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
const
std
::
string
&
og
,
const
std
::
string
&
og
,
int
dst_dev_id
)
const
{
int
dst_dev_id
)
const
{
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
ReduceOpHandle
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
ReduceOpHandle
(
result
->
CreateEmptyNode
(
"reduce"
,
ir
::
Node
::
Type
::
kOperation
),
result
->
CreateEmptyNode
(
"reduce"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
,
nccl_ctxs_
));
local_scopes_
,
places_
,
nccl_ctxs_
));
#else
#else
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
ReduceOpHandle
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
ReduceOpHandle
(
result
->
CreateEmptyNode
(
"reduce"
,
ir
::
Node
::
Type
::
kOperation
),
result
->
CreateEmptyNode
(
"reduce"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
));
local_scopes_
,
places_
));
#endif
#endif
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
();
auto
*
op_handle
=
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
().
get
();
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
&
p
=
places_
[
i
];
auto
&
p
=
places_
[
i
];
SetCommunicationContext
(
op_handle
,
p
);
SetCommunicationContext
(
op_handle
,
p
);
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
"vars"
)[
i
][
og
];
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
kGraphVars
)[
i
][
og
];
PADDLE_ENFORCE
(
!
vars
.
empty
());
PADDLE_ENFORCE
(
!
vars
.
empty
());
auto
&
prev_grad
=
vars
.
back
();
auto
&
prev_grad
=
vars
.
back
();
op_handle
->
AddInput
(
prev_grad
.
get
());
op_handle
->
AddInput
(
prev_grad
.
get
());
}
}
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
"vars"
)[
dst_dev_id
][
og
];
auto
&
vars
=
result
->
Get
<
GraphVars
>
(
kGraphVars
)[
dst_dev_id
][
og
];
auto
var
=
auto
var
=
new
VarHandle
(
result
->
CreateEmptyNode
(
og
,
ir
::
Node
::
Type
::
kVariable
),
new
VarHandle
(
result
->
CreateEmptyNode
(
og
,
ir
::
Node
::
Type
::
kVariable
),
vars
.
size
(),
dst_dev_id
,
og
,
places_
[
dst_dev_id
]);
vars
.
size
(),
dst_dev_id
,
og
,
places_
[
dst_dev_id
]);
...
@@ -613,11 +612,11 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
...
@@ -613,11 +612,11 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
// on it.
// on it.
void
MultiDevSSAGraphBuilder
::
ConnectOp
(
ir
::
Graph
*
result
,
OpHandleBase
*
op
,
void
MultiDevSSAGraphBuilder
::
ConnectOp
(
ir
::
Graph
*
result
,
OpHandleBase
*
op
,
const
std
::
string
&
prev_op_name
)
const
{
const
std
::
string
&
prev_op_name
)
const
{
for
(
auto
&
prev_op
:
result
->
Get
<
GraphOps
>
(
"ops"
))
{
for
(
auto
&
prev_op
:
result
->
Get
<
GraphOps
>
(
kGraphOps
))
{
if
(
prev_op
->
Name
()
==
prev_op_name
)
{
if
(
prev_op
->
Name
()
==
prev_op_name
)
{
auto
*
dep_var
=
new
DummyVarHandle
(
result
->
CreateControlDepVar
());
auto
*
dep_var
=
new
DummyVarHandle
(
result
->
CreateControlDepVar
());
prev_op
->
AddOutput
(
dep_var
);
prev_op
->
AddOutput
(
dep_var
);
result
->
Get
<
GraphDepVars
>
(
"dep_vars"
).
emplace
(
dep_var
);
result
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dep_var
);
op
->
AddInput
(
dep_var
);
op
->
AddInput
(
dep_var
);
}
}
}
}
...
@@ -638,20 +637,23 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
...
@@ -638,20 +637,23 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
if
(
node
->
Op
()
->
Type
()
==
"split_byref"
||
if
(
node
->
Op
()
->
Type
()
==
"split_byref"
||
node
->
Op
()
->
Type
()
==
"split_selected_rows"
)
{
node
->
Op
()
->
Type
()
==
"split_selected_rows"
)
{
// TODO(paddle-dev): getting the first var is not safe.
// TODO(paddle-dev): getting the first var is not safe.
op_dev_id
=
GetVarDeviceID
(
input_var_names
[
0
]);
op_dev_id
=
GetVarDeviceID
(
*
result
,
input_var_names
[
0
]);
if
(
strategy_
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kAllReduce
)
{
if
(
strategy_
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kAllReduce
)
{
op_dev_id
=
GetAppropriateDeviceID
(
input_var_names
);
op_dev_id
=
GetAppropriateDeviceID
(
input_var_names
);
for
(
auto
&
varname
:
input_var_names
)
{
for
(
auto
&
varname
:
input_var_names
)
{
var_name_on_devices_
.
emplace
(
varname
,
op_dev_id
);
result
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
.
emplace
(
varname
,
op_dev_id
);
}
}
}
}
for
(
auto
&
varname
:
output_var_names
)
{
for
(
auto
&
varname
:
output_var_names
)
{
var_name_on_devices_
.
emplace
(
varname
,
op_dev_id
);
result
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
.
emplace
(
varname
,
op_dev_id
);
}
}
}
else
if
(
node
->
Op
()
->
Type
()
==
"concat"
)
{
}
else
if
(
node
->
Op
()
->
Type
()
==
"concat"
)
{
op_dev_id
=
GetVarDeviceID
(
input_var_names
[
0
]);
op_dev_id
=
GetVarDeviceID
(
*
result
,
input_var_names
[
0
]);
for
(
auto
&
varname
:
output_var_names
)
{
for
(
auto
&
varname
:
output_var_names
)
{
var_name_on_devices_
.
emplace
(
varname
,
op_dev_id
);
result
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
.
emplace
(
varname
,
op_dev_id
);
}
}
}
else
{
}
else
{
PADDLE_ENFORCE
(
PADDLE_ENFORCE
(
...
@@ -665,7 +667,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
...
@@ -665,7 +667,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
CreateComputationalOp
(
result
,
node
,
op_dev_id
);
CreateComputationalOp
(
result
,
node
,
op_dev_id
);
if
(
node
->
Op
()
->
Type
()
==
"concat"
)
{
if
(
node
->
Op
()
->
Type
()
==
"concat"
)
{
ConnectOp
(
result
,
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
(),
ConnectOp
(
result
,
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
().
get
(),
"fetch_barrier"
);
"fetch_barrier"
);
}
}
}
}
...
@@ -676,7 +678,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
...
@@ -676,7 +678,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
int
op_dev_id
=
-
1
;
int
op_dev_id
=
-
1
;
if
(
node
->
Op
()
->
Type
()
==
"send"
)
{
if
(
node
->
Op
()
->
Type
()
==
"send"
)
{
// TODO(paddle-dev): getting the first var is not safe.
// TODO(paddle-dev): getting the first var is not safe.
op_dev_id
=
GetVarDeviceID
(
node
->
inputs
[
0
]
->
Name
());
op_dev_id
=
GetVarDeviceID
(
*
result
,
node
->
inputs
[
0
]
->
Name
());
PADDLE_ENFORCE
(
!
ir
::
IsControlDepVar
(
*
node
->
inputs
[
0
]),
PADDLE_ENFORCE
(
!
ir
::
IsControlDepVar
(
*
node
->
inputs
[
0
]),
"This hack no longer holds, please fix."
);
"This hack no longer holds, please fix."
);
// the variable name which contains .block means it was splited by
// the variable name which contains .block means it was splited by
...
@@ -691,7 +693,8 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
...
@@ -691,7 +693,8 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
}
}
op_dev_id
=
GetAppropriateDeviceID
(
input_var_names
);
op_dev_id
=
GetAppropriateDeviceID
(
input_var_names
);
for
(
auto
&
varname
:
input_var_names
)
{
for
(
auto
&
varname
:
input_var_names
)
{
var_name_on_devices_
.
emplace
(
varname
,
op_dev_id
);
result
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
.
emplace
(
varname
,
op_dev_id
);
}
}
}
}
}
else
if
(
node
->
Op
()
->
Type
()
==
"recv"
)
{
}
else
if
(
node
->
Op
()
->
Type
()
==
"recv"
)
{
...
@@ -701,7 +704,8 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
...
@@ -701,7 +704,8 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
}
}
op_dev_id
=
GetAppropriateDeviceID
(
output_var_names
);
op_dev_id
=
GetAppropriateDeviceID
(
output_var_names
);
for
(
auto
&
varname
:
output_var_names
)
{
for
(
auto
&
varname
:
output_var_names
)
{
var_name_on_devices_
.
emplace
(
varname
,
op_dev_id
);
result
->
Get
<
ShardedVarDevice
>
(
kShardedVarDevice
)
.
emplace
(
varname
,
op_dev_id
);
}
}
}
else
{
}
else
{
// send_barrier and fetch_barrier op can be scheduled on device 0
// send_barrier and fetch_barrier op can be scheduled on device 0
...
@@ -711,18 +715,18 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
...
@@ -711,18 +715,18 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
PADDLE_ENFORCE
(
op_dev_id
!=
-
1
,
"can not find the right place for rpc op: %s"
,
PADDLE_ENFORCE
(
op_dev_id
!=
-
1
,
"can not find the right place for rpc op: %s"
,
node
->
Op
()
->
Type
());
node
->
Op
()
->
Type
());
result
->
Get
<
GraphOps
>
(
"ops"
).
emplace_back
(
new
RPCOpHandle
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
RPCOpHandle
(
result
->
CreateOpNode
(
node
->
Op
()),
*
node
->
Op
(),
local_scopes_
[
op_dev_id
],
result
->
CreateOpNode
(
node
->
Op
()),
*
node
->
Op
(),
local_scopes_
[
op_dev_id
],
node
->
Op
()
->
Type
(),
places_
[
op_dev_id
]));
node
->
Op
()
->
Type
(),
places_
[
op_dev_id
]));
// TODO(panyx0718): This might not be needed anymore.
// TODO(panyx0718): This might not be needed anymore.
if
(
node
->
Op
()
->
Type
()
==
"send_barrier"
)
{
if
(
node
->
Op
()
->
Type
()
==
"send_barrier"
)
{
ConnectOp
(
result
,
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
(),
"send"
);
ConnectOp
(
result
,
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
().
get
(),
"send"
);
}
else
if
(
node
->
Op
()
->
Type
()
==
"recv"
)
{
}
else
if
(
node
->
Op
()
->
Type
()
==
"recv"
)
{
ConnectOp
(
result
,
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
(),
ConnectOp
(
result
,
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
().
get
(),
"send_barrier"
);
"send_barrier"
);
}
else
if
(
node
->
Op
()
->
Type
()
==
"fetch_barrier"
)
{
}
else
if
(
node
->
Op
()
->
Type
()
==
"fetch_barrier"
)
{
ConnectOp
(
result
,
result
->
Get
<
GraphOps
>
(
"ops"
).
back
().
get
(),
"recv"
);
ConnectOp
(
result
,
result
->
Get
<
GraphOps
>
(
kGraphOps
).
back
().
get
(),
"recv"
);
}
else
if
(
node
->
Op
()
->
Type
()
==
"send"
)
{
}
else
if
(
node
->
Op
()
->
Type
()
==
"send"
)
{
// do nothing
// do nothing
}
else
{
}
else
{
...
@@ -744,3 +748,11 @@ bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {
...
@@ -744,3 +748,11 @@ bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
REGISTER_PASS
(
multi_device_pass
,
paddle
::
framework
::
details
::
MultiDevSSAGraphBuilder
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kLossVarName
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kPlaces
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kParams
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kLocalScopes
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kStrategy
);
paddle/fluid/framework/details/multi_devices_graph_builder.h
浏览文件 @
a5c96af3
...
@@ -31,39 +31,27 @@ class Scope;
...
@@ -31,39 +31,27 @@ class Scope;
namespace
details
{
namespace
details
{
class
MultiDevSSAGraphBuilder
:
public
SSAGraphBuilder
{
class
MultiDevSSAGraphBuilder
:
public
SSAGraphBuilder
{
public:
protected:
#ifdef PADDLE_WITH_CUDA
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
MultiDevSSAGraphBuilder
(
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>
&
params
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
platform
::
NCCLContextMap
*
nccl_ctxs
,
const
BuildStrategy
&
strategy
);
#else
MultiDevSSAGraphBuilder
(
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>
&
params
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
BuildStrategy
&
strategy
);
#endif
std
::
unique_ptr
<
ir
::
Graph
>
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
int
GetVarDeviceID
(
const
std
::
string
&
varname
)
const
override
;
private:
private:
void
CreateOpHandleIOs
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
void
CreateOpHandleIOs
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
,
size_t
device_id
)
const
;
size_t
device_id
)
const
;
void
Init
()
const
;
private:
private:
std
::
string
loss_var_name_
;
mutable
std
::
string
loss_var_name_
;
const
std
::
vector
<
platform
::
Place
>
&
places_
;
mutable
std
::
vector
<
platform
::
Place
>
places_
;
const
std
::
vector
<
Scope
*>
&
local_scopes_
;
mutable
std
::
vector
<
Scope
*>
local_scopes_
;
std
::
unordered_set
<
std
::
string
>
grad_names_
;
mutable
std
::
unordered_set
<
std
::
string
>
grad_names_
;
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
platform
::
NCCLContextMap
*
nccl_ctxs_
;
mutable
platform
::
NCCLContextMap
*
nccl_ctxs_
;
#endif
#endif
int
GetVarDeviceID
(
const
ir
::
Graph
&
graph
,
const
std
::
string
&
varname
)
const
;
bool
IsScaleLossOp
(
ir
::
Node
*
node
)
const
;
bool
IsScaleLossOp
(
ir
::
Node
*
node
)
const
;
void
CreateRPCOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
;
void
CreateRPCOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
;
...
@@ -97,7 +85,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
...
@@ -97,7 +85,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const
std
::
string
&
og
,
const
std
::
string
&
og
,
std
::
unordered_set
<
std
::
string
>
*
og_has_been_broadcast
)
const
;
std
::
unordered_set
<
std
::
string
>
*
og_has_been_broadcast
)
const
;
int
GetOpDeviceID
(
ir
::
Node
*
node
)
const
;
int
GetOpDeviceID
(
const
ir
::
Graph
&
graph
,
ir
::
Node
*
node
)
const
;
void
InsertAllReduceOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
og
)
const
;
void
InsertAllReduceOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
og
)
const
;
...
@@ -113,9 +101,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
...
@@ -113,9 +101,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const
std
::
vector
<
std
::
string
>
&
var_names
)
const
;
const
std
::
vector
<
std
::
string
>
&
var_names
)
const
;
private:
private:
BuildStrategy
strategy_
;
mutable
BuildStrategy
strategy_
;
mutable
std
::
unordered_map
<
std
::
string
,
VarDesc
*>
all_vars_
;
mutable
std
::
unordered_map
<
std
::
string
,
VarDesc
*>
all_vars_
;
mutable
std
::
unordered_map
<
std
::
string
,
int
>
var_name_on_devices_
;
mutable
std
::
vector
<
int64_t
>
balance_vars_
;
mutable
std
::
vector
<
int64_t
>
balance_vars_
;
void
SetCommunicationContext
(
OpHandleBase
*
op_handle
,
void
SetCommunicationContext
(
OpHandleBase
*
op_handle
,
...
...
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h
浏览文件 @
a5c96af3
...
@@ -40,6 +40,11 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
...
@@ -40,6 +40,11 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
ExecutionStrategy
strategy
,
std
::
vector
<
Scope
*>
local_scopes
,
ExecutionStrategy
strategy
,
std
::
vector
<
Scope
*>
local_scopes
,
std
::
vector
<
VariableInfo
>
var_infos
,
std
::
vector
<
platform
::
Place
>
places
,
std
::
vector
<
VariableInfo
>
var_infos
,
std
::
vector
<
platform
::
Place
>
places
,
std
::
unique_ptr
<
SSAGraphExecutor
>&&
underlying_executor
);
std
::
unique_ptr
<
SSAGraphExecutor
>&&
underlying_executor
);
const
ir
::
Graph
&
Graph
()
const
override
{
return
underlying_executor_
->
Graph
();
}
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>&
fetch_tensors
)
override
;
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>&
fetch_tensors
)
override
;
private:
private:
...
...
paddle/fluid/framework/details/ssa_graph_builder.cc
浏览文件 @
a5c96af3
...
@@ -18,7 +18,7 @@ namespace paddle {
...
@@ -18,7 +18,7 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
void
SSAGraphBuilder
::
PolishGraphToSupportDataHazards
(
ir
::
Graph
*
graph
)
{
void
SSAGraphBuilder
::
PolishGraphToSupportDataHazards
(
ir
::
Graph
*
graph
)
{
for
(
auto
&
var_map
:
graph
->
Get
<
GraphVars
>
(
"vars"
))
{
for
(
auto
&
var_map
:
graph
->
Get
<
GraphVars
>
(
kGraphVars
))
{
for
(
auto
&
name_pair
:
var_map
)
{
for
(
auto
&
name_pair
:
var_map
)
{
if
(
name_pair
.
second
.
size
()
<=
1
)
{
if
(
name_pair
.
second
.
size
()
<=
1
)
{
continue
;
continue
;
...
@@ -50,7 +50,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
...
@@ -50,7 +50,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
auto
*
dep_var
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
auto
*
dep_var
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
read_op
->
AddOutput
(
dep_var
);
read_op
->
AddOutput
(
dep_var
);
write_op
->
AddInput
(
dep_var
);
write_op
->
AddInput
(
dep_var
);
graph
->
Get
<
GraphDepVars
>
(
"dep_vars"
).
emplace
(
dep_var
);
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dep_var
);
}
}
}
}
}
}
...
@@ -60,7 +60,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
...
@@ -60,7 +60,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
VarHandle
*
SSAGraphBuilder
::
CreateOrGetLatestVarHandle
(
VarHandle
*
SSAGraphBuilder
::
CreateOrGetLatestVarHandle
(
ir
::
Graph
*
graph
,
ir
::
Node
*
node
,
const
platform
::
Place
&
place
,
ir
::
Graph
*
graph
,
ir
::
Node
*
node
,
const
platform
::
Place
&
place
,
size_t
place_offset
)
{
size_t
place_offset
)
{
auto
&
var_holders
=
graph
->
Get
<
GraphVars
>
(
"vars"
)[
place_offset
];
auto
&
var_holders
=
graph
->
Get
<
GraphVars
>
(
kGraphVars
)[
place_offset
];
auto
&
var_holder
=
var_holders
[
node
->
Name
()];
auto
&
var_holder
=
var_holders
[
node
->
Name
()];
VarHandle
*
var
=
nullptr
;
VarHandle
*
var
=
nullptr
;
if
(
var_holder
.
empty
())
{
if
(
var_holder
.
empty
())
{
...
@@ -83,7 +83,8 @@ void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle,
...
@@ -83,7 +83,8 @@ void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle,
ir
::
Node
*
new_node
,
ir
::
Node
*
new_node
,
const
platform
::
Place
&
place
,
const
platform
::
Place
&
place
,
size_t
place_offset
)
{
size_t
place_offset
)
{
auto
&
vars
=
graph
->
Get
<
GraphVars
>
(
"vars"
)[
place_offset
][
new_node
->
Name
()];
auto
&
vars
=
graph
->
Get
<
GraphVars
>
(
kGraphVars
)[
place_offset
][
new_node
->
Name
()];
size_t
version
=
vars
.
size
();
size_t
version
=
vars
.
size
();
auto
var
=
auto
var
=
new
VarHandle
(
new_node
,
version
,
place_offset
,
new_node
->
Name
(),
place
);
new
VarHandle
(
new_node
,
version
,
place_offset
,
new_node
->
Name
(),
place
);
...
@@ -92,12 +93,12 @@ void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle,
...
@@ -92,12 +93,12 @@ void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle,
}
}
void
SSAGraphBuilder
::
AddOutputToLeafOps
(
ir
::
Graph
*
graph
)
{
void
SSAGraphBuilder
::
AddOutputToLeafOps
(
ir
::
Graph
*
graph
)
{
for
(
auto
&
op
:
graph
->
Get
<
GraphOps
>
(
"ops"
))
{
for
(
auto
&
op
:
graph
->
Get
<
GraphOps
>
(
kGraphOps
))
{
if
(
!
op
->
Outputs
().
empty
())
{
if
(
!
op
->
Outputs
().
empty
())
{
continue
;
continue
;
}
}
auto
*
dummy_leaf
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
auto
*
dummy_leaf
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
graph
->
Get
<
GraphDepVars
>
(
"dep_vars"
).
emplace
(
dummy_leaf
);
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dummy_leaf
);
op
->
AddOutput
(
dummy_leaf
);
op
->
AddOutput
(
dummy_leaf
);
}
}
}
}
...
...
paddle/fluid/framework/details/ssa_graph_builder.h
浏览文件 @
a5c96af3
...
@@ -39,21 +39,25 @@ namespace details {
...
@@ -39,21 +39,25 @@ namespace details {
typedef
std
::
vector
<
typedef
std
::
vector
<
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
unique_ptr
<
VarHandle
>>>>
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
unique_ptr
<
VarHandle
>>>>
GraphVars
;
GraphVars
;
const
char
kGraphVars
[]
=
"vars"
;
// aux variables to represent dependency. Useful to resolve data hazard.
// aux variables to represent dependency. Useful to resolve data hazard.
typedef
std
::
unordered_set
<
std
::
unique_ptr
<
VarHandleBase
>>
GraphDepVars
;
typedef
std
::
unordered_set
<
std
::
unique_ptr
<
VarHandleBase
>>
GraphDepVars
;
const
char
kGraphDepVars
[]
=
"dep_vars"
;
// all operators. NOTE that even we use a vector here, the operators is
// all operators. NOTE that even we use a vector here, the operators is
// unordered.
// unordered.
typedef
std
::
vector
<
std
::
unique_ptr
<
OpHandleBase
>>
GraphOps
;
typedef
std
::
vector
<
std
::
unique_ptr
<
OpHandleBase
>>
GraphOps
;
const
char
kGraphOps
[]
=
"ops"
;
typedef
std
::
unordered_map
<
std
::
string
,
int
>
ShardedVarDevice
;
const
char
kShardedVarDevice
[]
=
"sharded_var_device"
;
class
SSAGraphBuilder
:
public
ir
::
Pass
{
class
SSAGraphBuilder
:
public
ir
::
Pass
{
public:
public:
SSAGraphBuilder
()
{}
SSAGraphBuilder
()
{}
virtual
~
SSAGraphBuilder
()
{}
virtual
~
SSAGraphBuilder
()
{}
virtual
int
GetVarDeviceID
(
const
std
::
string
&
var_name
)
const
=
0
;
DISABLE_COPY_AND_ASSIGN
(
SSAGraphBuilder
);
DISABLE_COPY_AND_ASSIGN
(
SSAGraphBuilder
);
protected:
protected:
...
...
paddle/fluid/framework/details/ssa_graph_checker.cc
浏览文件 @
a5c96af3
...
@@ -33,7 +33,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
...
@@ -33,7 +33,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
}
}
};
};
for
(
auto
&
var_map
:
graph
->
Get
<
GraphVars
>
(
"vars"
))
{
for
(
auto
&
var_map
:
graph
->
Get
<
GraphVars
>
(
kGraphVars
))
{
for
(
auto
&
name_pair
:
var_map
)
{
for
(
auto
&
name_pair
:
var_map
)
{
for
(
auto
&
version_pair
:
name_pair
.
second
)
{
for
(
auto
&
version_pair
:
name_pair
.
second
)
{
insert_pending_var
(
version_pair
.
get
());
insert_pending_var
(
version_pair
.
get
());
...
@@ -41,11 +41,11 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
...
@@ -41,11 +41,11 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
}
}
}
}
for
(
auto
&
var
:
graph
->
Get
<
GraphDepVars
>
(
"dep_vars"
))
{
for
(
auto
&
var
:
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
))
{
insert_pending_var
(
var
.
get
());
insert_pending_var
(
var
.
get
());
}
}
for
(
auto
&
op
:
graph
->
Get
<
GraphOps
>
(
"ops"
))
{
for
(
auto
&
op
:
graph
->
Get
<
GraphOps
>
(
kGraphOps
))
{
if
(
op
->
Inputs
().
empty
())
{
if
(
op
->
Inputs
().
empty
())
{
ready_ops
.
insert
(
op
.
get
());
ready_ops
.
insert
(
op
.
get
());
}
else
{
}
else
{
...
@@ -85,3 +85,10 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
...
@@ -85,3 +85,10 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
REGISTER_PASS
(
multi_device_check_pass
,
paddle
::
framework
::
details
::
SSAGraghBuilderWithChecker
)
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kGraphVars
)
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kGraphDepVars
)
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kGraphOps
)
.
RequireGraphAttr
(
paddle
::
framework
::
details
::
kShardedVarDevice
);
paddle/fluid/framework/details/ssa_graph_checker.h
浏览文件 @
a5c96af3
...
@@ -23,26 +23,14 @@ namespace framework {
...
@@ -23,26 +23,14 @@ namespace framework {
namespace
details
{
namespace
details
{
class
SSAGraghBuilderWithChecker
:
public
SSAGraphBuilder
{
class
SSAGraghBuilderWithChecker
:
public
SSAGraphBuilder
{
public:
protected:
explicit
SSAGraghBuilderWithChecker
(
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
SSAGraphBuilder
>&&
builder
)
:
builder_
(
std
::
move
(
builder
))
{}
std
::
unique_ptr
<
ir
::
Graph
>
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
{
auto
new_graph
=
builder_
->
Apply
(
std
::
move
(
graph
));
PADDLE_ENFORCE
(
IsValidGraph
(
graph
.
get
()));
PADDLE_ENFORCE
(
IsValidGraph
(
new_graph
.
get
()));
return
graph
;
return
new_graph
;
}
int
GetVarDeviceID
(
const
std
::
string
&
var_name
)
const
override
{
return
builder_
->
GetVarDeviceID
(
var_name
);
}
}
bool
IsValidGraph
(
const
ir
::
Graph
*
graph
)
const
;
bool
IsValidGraph
(
const
ir
::
Graph
*
graph
)
const
;
private:
std
::
unique_ptr
<
SSAGraphBuilder
>
builder_
;
};
};
}
// namespace details
}
// namespace details
...
...
paddle/fluid/framework/details/ssa_graph_executor.h
浏览文件 @
a5c96af3
...
@@ -32,7 +32,9 @@ class SSAGraphExecutor {
...
@@ -32,7 +32,9 @@ class SSAGraphExecutor {
virtual
~
SSAGraphExecutor
();
virtual
~
SSAGraphExecutor
();
virtual
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
=
0
;
virtual
const
ir
::
Graph
&
Graph
()
const
=
0
;
virtual
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>&
fetch_tensors
)
=
0
;
};
};
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/framework/details/ssa_graph_printer.cc
浏览文件 @
a5c96af3
...
@@ -22,7 +22,7 @@ namespace details {
...
@@ -22,7 +22,7 @@ namespace details {
template
<
typename
Callback
>
template
<
typename
Callback
>
static
inline
void
IterAllVar
(
const
ir
::
Graph
&
graph
,
Callback
callback
)
{
static
inline
void
IterAllVar
(
const
ir
::
Graph
&
graph
,
Callback
callback
)
{
for
(
auto
&
each
:
graph
.
Get
<
GraphVars
>
(
"vars"
))
{
for
(
auto
&
each
:
graph
.
Get
<
GraphVars
>
(
kGraphVars
))
{
for
(
auto
&
pair1
:
each
)
{
for
(
auto
&
pair1
:
each
)
{
for
(
auto
&
pair2
:
pair1
.
second
)
{
for
(
auto
&
pair2
:
pair1
.
second
)
{
callback
(
*
pair2
);
callback
(
*
pair2
);
...
@@ -30,7 +30,7 @@ static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
...
@@ -30,7 +30,7 @@ static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
}
}
}
}
for
(
auto
&
var
:
graph
.
Get
<
GraphDepVars
>
(
"dep_vars"
))
{
for
(
auto
&
var
:
graph
.
Get
<
GraphDepVars
>
(
kGraphDepVars
))
{
callback
(
*
var
);
callback
(
*
var
);
}
}
}
}
...
@@ -61,7 +61,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
...
@@ -61,7 +61,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
});
});
size_t
op_id
=
0
;
size_t
op_id
=
0
;
for
(
auto
&
op
:
graph
.
Get
<
GraphOps
>
(
"ops"
))
{
for
(
auto
&
op
:
graph
.
Get
<
GraphOps
>
(
kGraphOps
))
{
std
::
string
op_name
=
"op_"
+
std
::
to_string
(
op_id
++
);
std
::
string
op_name
=
"op_"
+
std
::
to_string
(
op_id
++
);
sout
<<
op_name
<<
" [label=
\"
"
<<
op
->
Name
()
<<
"
\"
, shape=rect]"
sout
<<
op_name
<<
" [label=
\"
"
<<
op
->
Name
()
<<
"
\"
, shape=rect]"
<<
std
::
endl
;
<<
std
::
endl
;
...
@@ -81,3 +81,6 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
...
@@ -81,3 +81,6 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
REGISTER_PASS
(
multi_device_print_pass
,
paddle
::
framework
::
details
::
SSAGraghBuilderWithPrinter
);
paddle/fluid/framework/details/ssa_graph_printer.h
浏览文件 @
a5c96af3
...
@@ -14,7 +14,9 @@
...
@@ -14,7 +14,9 @@
#pragma once
#pragma once
#include <fstream>
#include <iosfwd>
#include <iosfwd>
#include <ostream>
#include <string>
#include <string>
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
...
@@ -34,38 +36,15 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
...
@@ -34,38 +36,15 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
};
};
class
SSAGraghBuilderWithPrinter
:
public
SSAGraphBuilder
{
class
SSAGraghBuilderWithPrinter
:
public
SSAGraphBuilder
{
public:
protected:
SSAGraghBuilderWithPrinter
(
std
::
ostream
&
sout
,
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
SSAGraphPrinter
>&&
printer
,
std
::
unique_ptr
<
SSAGraphBuilder
>&&
builder
)
:
printer_
(
std
::
move
(
printer
)),
builder_
(
std
::
move
(
builder
)),
stream_ref_
(
sout
)
{}
SSAGraghBuilderWithPrinter
(
std
::
unique_ptr
<
std
::
ostream
>&&
sout
,
std
::
unique_ptr
<
SSAGraphPrinter
>&&
printer
,
std
::
unique_ptr
<
SSAGraphBuilder
>&&
builder
)
:
printer_
(
std
::
move
(
printer
)),
builder_
(
std
::
move
(
builder
)),
stream_ptr_
(
std
::
move
(
sout
)),
stream_ref_
(
*
stream_ptr_
)
{}
std
::
unique_ptr
<
ir
::
Graph
>
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
{
auto
new_graph
=
builder_
->
Apply
(
std
::
move
(
graph
));
std
::
unique_ptr
<
std
::
ostream
>
fout
(
printer_
->
Print
(
*
new_graph
,
stream_ref_
);
new
std
::
ofstream
(
Get
<
const
std
::
string
>
(
"debug_graphviz_path"
)));
return
new_graph
;
PADDLE_ENFORCE
(
fout
->
good
());
Get
<
GraphvizSSAGraphPrinter
>
(
"graph_printer"
).
Print
(
*
graph
,
*
fout
);
return
graph
;
}
}
int
GetVarDeviceID
(
const
std
::
string
&
var_name
)
const
override
{
return
builder_
->
GetVarDeviceID
(
var_name
);
}
private:
std
::
unique_ptr
<
SSAGraphPrinter
>
printer_
;
std
::
unique_ptr
<
SSAGraphBuilder
>
builder_
;
std
::
unique_ptr
<
std
::
ostream
>
stream_ptr_
;
std
::
ostream
&
stream_ref_
;
};
};
}
// namespace details
}
// namespace details
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
浏览文件 @
a5c96af3
...
@@ -45,18 +45,18 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
...
@@ -45,18 +45,18 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
std
::
unordered_set
<
OpHandleBase
*>
delayed_ops
;
std
::
unordered_set
<
OpHandleBase
*>
delayed_ops
;
// Transform SSAGraph to pending_ops & pending_vars
// Transform SSAGraph to pending_ops & pending_vars
for
(
auto
&
var_map
:
graph_
->
Get
<
details
::
GraphVars
>
(
"vars"
))
{
for
(
auto
&
var_map
:
graph_
->
Get
<
details
::
GraphVars
>
(
details
::
kGraphVars
))
{
for
(
auto
&
name_pair
:
var_map
)
{
for
(
auto
&
name_pair
:
var_map
)
{
for
(
auto
&
version_pair
:
name_pair
.
second
)
{
for
(
auto
&
version_pair
:
name_pair
.
second
)
{
InsertPendingVar
(
&
pending_vars
,
&
ready_vars
,
version_pair
.
get
());
InsertPendingVar
(
&
pending_vars
,
&
ready_vars
,
version_pair
.
get
());
}
}
}
}
}
}
for
(
auto
&
var
:
graph_
->
Get
<
details
::
GraphDepVars
>
(
"dep_vars"
))
{
for
(
auto
&
var
:
graph_
->
Get
<
details
::
GraphDepVars
>
(
details
::
kGraphDepVars
))
{
InsertPendingVar
(
&
pending_vars
,
&
ready_vars
,
var
.
get
());
InsertPendingVar
(
&
pending_vars
,
&
ready_vars
,
var
.
get
());
}
}
for
(
auto
&
op
:
graph_
->
Get
<
details
::
GraphOps
>
(
"ops"
))
{
for
(
auto
&
op
:
graph_
->
Get
<
details
::
GraphOps
>
(
details
::
kGraphOps
))
{
if
(
op
->
Inputs
().
empty
())
{
// Special case, Op has no input.
if
(
op
->
Inputs
().
empty
())
{
// Special case, Op has no input.
ready_ops
.
insert
(
op
.
get
());
ready_ops
.
insert
(
op
.
get
());
}
else
{
}
else
{
...
@@ -83,7 +83,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
...
@@ -83,7 +83,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
// Clean run context
// Clean run context
run_op_futures_
.
clear
();
run_op_futures_
.
clear
();
exception_
.
reset
();
exception_
holder_
.
Clear
();
// Step 3. Execution
// Step 3. Execution
while
(
!
pending_vars
.
empty
())
{
while
(
!
pending_vars
.
empty
())
{
...
@@ -103,23 +103,11 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
...
@@ -103,23 +103,11 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
auto
cur_ready_vars
=
ready_vars
.
PopAll
(
1
,
&
timeout
);
auto
cur_ready_vars
=
ready_vars
.
PopAll
(
1
,
&
timeout
);
if
(
timeout
)
{
if
(
timeout
)
{
std
::
unique_lock
<
std
::
mutex
>
l
(
exception_mu_
);
if
(
exception_holder_
.
ExceptionCatched
())
{
if
(
exception_
)
{
l
.
unlock
();
for
(
auto
&
run_op_future
:
run_op_futures_
)
{
for
(
auto
&
run_op_future
:
run_op_futures_
)
{
run_op_future
.
wait
();
run_op_future
.
wait
();
}
}
l
.
lock
();
exception_holder_
.
Throw
();
std
::
exception
*
exp
=
exception_
.
get
();
if
(
dynamic_cast
<
platform
::
EOFException
*>
(
exp
))
{
auto
e
=
*
static_cast
<
platform
::
EOFException
*>
(
exp
);
throw
e
;
}
else
if
(
dynamic_cast
<
platform
::
EnforceNotMet
*>
(
exp
))
{
auto
e
=
*
static_cast
<
platform
::
EnforceNotMet
*>
(
exp
);
throw
e
;
}
else
{
LOG
(
FATAL
)
<<
"Unknown exception."
;
}
}
else
{
}
else
{
continue
;
continue
;
}
}
...
@@ -162,7 +150,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
...
@@ -162,7 +150,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
VarHandleBase
*>>
fetched_vars
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
VarHandleBase
*>>
fetched_vars
;
for
(
auto
&
fetch_var_name
:
fetch_tensors
)
{
for
(
auto
&
fetch_var_name
:
fetch_tensors
)
{
for
(
auto
&
var_map
:
graph_
->
Get
<
details
::
GraphVars
>
(
"vars"
))
{
for
(
auto
&
var_map
:
graph_
->
Get
<
details
::
GraphVars
>
(
details
::
kGraphVars
))
{
auto
it
=
var_map
.
find
(
fetch_var_name
);
auto
it
=
var_map
.
find
(
fetch_var_name
);
if
(
it
!=
var_map
.
end
())
{
if
(
it
!=
var_map
.
end
())
{
fetched_vars
[
fetch_var_name
].
push_back
(
it
->
second
.
rbegin
()
->
get
());
fetched_vars
[
fetch_var_name
].
push_back
(
it
->
second
.
rbegin
()
->
get
());
...
@@ -229,14 +217,9 @@ void ThreadedSSAGraphExecutor::RunOp(
...
@@ -229,14 +217,9 @@ void ThreadedSSAGraphExecutor::RunOp(
ready_var_q
->
Extend
(
op
->
Outputs
());
ready_var_q
->
Extend
(
op
->
Outputs
());
VLOG
(
10
)
<<
op
<<
" "
<<
op
->
Name
()
<<
"Signal posted"
;
VLOG
(
10
)
<<
op
<<
" "
<<
op
->
Name
()
<<
"Signal posted"
;
}
catch
(
platform
::
EOFException
ex
)
{
}
catch
(
platform
::
EOFException
ex
)
{
std
::
lock_guard
<
std
::
mutex
>
l
(
exception_mu_
);
exception_holder_
.
Catch
(
ex
);
// EOFException will not cover up existing EnforceNotMet.
if
(
exception_
.
get
()
==
nullptr
)
{
exception_
.
reset
(
new
platform
::
EOFException
(
ex
));
}
}
catch
(
platform
::
EnforceNotMet
ex
)
{
}
catch
(
platform
::
EnforceNotMet
ex
)
{
std
::
lock_guard
<
std
::
mutex
>
l
(
exception_mu_
);
exception_holder_
.
Catch
(
ex
);
exception_
.
reset
(
new
platform
::
EnforceNotMet
(
ex
));
}
catch
(...)
{
}
catch
(...)
{
LOG
(
FATAL
)
<<
"Unknown exception catched"
;
LOG
(
FATAL
)
<<
"Unknown exception catched"
;
}
}
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
浏览文件 @
a5c96af3
...
@@ -24,6 +24,7 @@
...
@@ -24,6 +24,7 @@
#include <functional>
#include <functional>
#include "ThreadPool.h" // ThreadPool in thrird party
#include "ThreadPool.h" // ThreadPool in thrird party
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/details/exception_holder.h"
#include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
...
@@ -42,6 +43,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
...
@@ -42,6 +43,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
std
::
unique_ptr
<
ir
::
Graph
>
&&
graph
);
std
::
unique_ptr
<
ir
::
Graph
>
&&
graph
);
// Read-only accessor for the executor's graph: returns a const reference to
// the object that graph_ points to. graph_ is dereferenced unconditionally,
// so it must be non-null when this is called.
const
ir
::
Graph
&
Graph
()
const
override
{
return
*
graph_
;
}
// Run a SSAGraph by a thread pool
// Run a SSAGraph by a thread pool
// Use topological sort algorithm
// Use topological sort algorithm
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
override
;
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
override
;
...
@@ -58,8 +60,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
...
@@ -58,8 +60,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
std
::
vector
<
Scope
*>
local_scopes_
;
std
::
vector
<
Scope
*>
local_scopes_
;
std
::
vector
<
platform
::
Place
>
places_
;
std
::
vector
<
platform
::
Place
>
places_
;
platform
::
DeviceContextPool
fetch_ctxs_
;
platform
::
DeviceContextPool
fetch_ctxs_
;
std
::
mutex
exception_mu_
;
ExceptionHolder
exception_holder_
;
std
::
unique_ptr
<
std
::
exception
>
exception_
;
std
::
atomic
<
int
>
running_ops_
;
std
::
atomic
<
int
>
running_ops_
;
void
InsertPendingOp
(
std
::
unordered_map
<
OpHandleBase
*
,
size_t
>
*
pending_ops
,
void
InsertPendingOp
(
std
::
unordered_map
<
OpHandleBase
*
,
size_t
>
*
pending_ops
,
...
...
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
a5c96af3
cc_library
(
node SRCS node.cc DEPS proto_desc
)
cc_library
(
node SRCS node.cc DEPS proto_desc
)
cc_library
(
graph SRCS graph.cc DEPS node
)
cc_library
(
graph SRCS graph.cc DEPS node
)
cc_library
(
graph_helper SRCS graph_helper.cc DEPS graph
)
cc_library
(
graph_helper SRCS graph_helper.cc DEPS graph
)
cc_library
(
pass SRCS pass.cc DEPS graph node
)
cc_library
(
pass SRCS pass.cc DEPS graph node graph_helper
)
cc_test
(
graph_test SRCS graph_test.cc DEPS graph op_registry
)
cc_library
(
graph_viz_pass SRCS graph_viz_pass.cc DEPS graph pass graph_helper
)
cc_test
(
graph_helper_test SRCS graph_helper_test.cc DEPS graph_helper op_registry
)
cc_test
(
pass_test SRCS pass_test.cc DEPS graph pass graph_helper
)
cc_test
(
graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry
)
cc_test
(
graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry
)
paddle/fluid/framework/ir/graph.h
浏览文件 @
a5c96af3
...
@@ -40,14 +40,21 @@ class Graph {
...
@@ -40,14 +40,21 @@ class Graph {
attr_dels_
.
clear
();
attr_dels_
.
clear
();
}
}
// Returns true when an attribute called `attr_name` is stored on this graph.
bool Has(const std::string &attr_name) const {
  return attrs_.count(attr_name) > 0;
}
template
<
typename
AttrType
>
template
<
typename
AttrType
>
AttrType
&
Get
(
const
std
::
string
&
attr_name
)
const
{
AttrType
&
Get
(
const
std
::
string
&
attr_name
)
const
{
PADDLE_ENFORCE
(
Has
(
attr_name
),
"%s attr not registered for graph."
,
attr_name
);
return
*
boost
::
any_cast
<
AttrType
*>
(
attrs_
.
at
(
attr_name
));
return
*
boost
::
any_cast
<
AttrType
*>
(
attrs_
.
at
(
attr_name
));
}
}
template
<
typename
AttrType
>
template
<
typename
AttrType
>
void
Set
(
const
std
::
string
&
attr_name
,
AttrType
*
attr
)
{
void
Set
(
const
std
::
string
&
attr_name
,
AttrType
*
attr
)
{
PADDLE_ENFORCE
(
attrs_
.
count
(
attr_name
)
==
0
);
PADDLE_ENFORCE
(
attrs_
.
count
(
attr_name
)
==
0
,
"%s already set in the graph"
,
attr_name
);
attrs_
[
attr_name
]
=
attr
;
attrs_
[
attr_name
]
=
attr
;
attr_dels_
[
attr_name
]
=
[
attr
,
attr_name
]()
{
attr_dels_
[
attr_name
]
=
[
attr
,
attr_name
]()
{
VLOG
(
3
)
<<
"deleting "
<<
attr_name
;
VLOG
(
3
)
<<
"deleting "
<<
attr_name
;
...
...
paddle/fluid/framework/ir/graph_viz_pass.cc
0 → 100644
浏览文件 @
a5c96af3
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
static
const
char
kGraphVizPath
[]
=
"graph_viz_path"
;
// Writes `graph` in Graphviz dot format to the file named by the
// "graph_viz_path" pass attribute, then hands the same graph back.
//
// Variable nodes are emitted as `var_<id>` and operation nodes as
// `op_<id>` (drawn as rectangles); one edge is written for every input and
// output of each operation node.
//
// Fails a PADDLE_ENFORCE check when the output file cannot be opened.
std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  const std::string graph_viz_path = Get<std::string>(kGraphVizPath);

  // Stack-allocated stream: flushed and closed automatically on every exit
  // path, including when an enforce below throws.
  std::ofstream fout(graph_viz_path);
  PADDLE_ENFORCE(fout.good(), "Cannot open %s to write the graphviz dump.",
                 graph_viz_path);
  std::ostream& sout = fout;

  size_t var_id = 0;
  // Maps every variable node to the numeric id used in its dot label.
  std::unordered_map<const ir::Node*, size_t> vars;

  sout << "digraph G {\n";

  // First pass: declare one dot node per variable.
  for (const ir::Node* n : graph->Nodes()) {
    if (n->NodeType() != ir::Node::Type::kVariable) continue;
    size_t cur_var_id = var_id++;
    vars[n] = cur_var_id;

    sout << "var_" << cur_var_id << " [label=\"" << n->Name() << "\"]"
         << std::endl;
  }

  // Second pass: declare the operations and connect them to their variables.
  size_t op_id = 0;
  for (const ir::Node* n : graph->Nodes()) {
    if (n->NodeType() != ir::Node::Type::kOperation) continue;
    std::string op_name = "op_" + std::to_string(op_id++);
    sout << op_name << " [label=\"" << n->Name() << "\", shape=rect]"
         << std::endl;
    // NOTE(review): operator[] default-inserts id 0 for a node that was not
    // registered in the first pass, so such an edge would point at var_0.
    for (auto in : n->inputs) {
      std::string var_name = "var_" + std::to_string(vars[in]);
      sout << var_name << " -> " << op_name << std::endl;
    }
    for (auto out : n->outputs) {
      std::string var_name = "var_" + std::to_string(vars[out]);
      sout << op_name << " -> " << var_name << std::endl;
    }
  }

  sout << "}\n";
  return graph;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
graph_viz_pass
,
paddle
::
framework
::
ir
::
GraphVizPass
)
.
RequirePassAttr
(
paddle
::
framework
::
ir
::
kGraphVizPath
);
paddle/fluid/framework/ir/graph_viz_pass.h
0 → 100644
浏览文件 @
a5c96af3
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <fstream>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
// Pass whose ApplyImpl() is defined in graph_viz_pass.cc, where it writes the
// graph as Graphviz dot text to the file named by the "graph_viz_path" pass
// attribute and returns the graph it was given.
class
GraphVizPass
:
public
Pass
{
protected:
// Overrides the Pass hook; takes ownership of `graph` and returns a graph.
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/pass.cc
浏览文件 @
a5c96af3
...
@@ -13,7 +13,34 @@ See the License for the specific language governing permissions and
...
@@ -13,7 +13,34 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{}
// namespace framework
namespace
framework
{
namespace
ir
{
// Validates the preconditions of this pass, runs ApplyImpl() on `graph`, and
// verifies the result before returning it.
//
// Checks, in order (each one is a PADDLE_ENFORCE and throws on failure):
//   * the pass has not been successfully applied before,
//   * `graph` is non-null,
//   * every required pass attribute has been set on this pass,
//   * every required graph attribute is present on `graph`,
//   * ApplyImpl() returned a non-null graph,
//   * the returned graph has no cycle.
//
// `applied_` is only flipped after all checks pass, so a failed Apply() can
// be retried on the same pass object.
//
// NOTE: the message strings are kept verbatim (typos included) because
// callers/tests may match on substrings of them.
std::unique_ptr<Graph> Pass::Apply(std::unique_ptr<Graph> graph) const {
  PADDLE_ENFORCE(!applied_, "Pass can only Apply() once.");

  PADDLE_ENFORCE(graph.get(), "graph passed to Pass::Apply() cannot be empty.");
  for (const std::string& attr : required_pass_attrs_) {
    PADDLE_ENFORCE(attrs_.find(attr) != attrs_.end(),
                   "Required pass atrribute %s not set.", attr);
  }
  for (const std::string& attr : required_graph_attrs_) {
    PADDLE_ENFORCE(graph->Has(attr), "Required graph atrribute %s not set.",
                   attr);
  }

  auto applied_graph = ApplyImpl(std::move(graph));
  // Guard the dereference below: an implementation that returns an empty
  // pointer would otherwise cause undefined behaviour in HasCircle().
  PADDLE_ENFORCE(applied_graph.get(),
                 "Pass::ApplyImpl() returned an empty graph.");
  // TODO(panyx0718): Add more verifications.
  PADDLE_ENFORCE(!HasCircle(*applied_graph),
                 "Illegal Pass. Generated graph shouldn't has cycle.");
  applied_ = true;
  return applied_graph;
}
// Returns the single process-wide registry, constructed on first use as a
// function-local static.
PassRegistry& PassRegistry::Instance() {
  static PassRegistry registry;
  return registry;
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/ir/pass.h
浏览文件 @
a5c96af3
...
@@ -14,21 +14,187 @@ limitations under the License. */
...
@@ -14,21 +14,187 @@ limitations under the License. */
#pragma once
#pragma once
#include <functional>
#include <map>
#include <string>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/variant.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
ir
{
namespace
ir
{
template
<
typename
PassType
>
struct
PassRegistrar
;
class
Pass
{
class
Pass
{
public:
public:
Pass
()
=
default
;
Pass
()
=
default
;
virtual
~
Pass
()
{}
// Runs the registered deleter of every attribute that has one (attributes
// stored through SetNotOwned() have none and are left alone), then empties
// both tables.
virtual ~Pass() {
  for (auto& name_and_value : attrs_) {
    auto deleter = attr_dels_.find(name_and_value.first);
    if (deleter == attr_dels_.end()) {
      continue;
    }
    deleter->second();
  }
  attrs_.clear();
  attr_dels_.clear();
}
std
::
unique_ptr
<
Graph
>
Apply
(
std
::
unique_ptr
<
Graph
>
graph
)
const
;
// Returns a reference to the attribute previously stored under `attr_name`.
// Fails a PADDLE_ENFORCE check when no such attribute exists. The stored
// value is read back as an AttrType*, so AttrType must match the type used
// when the attribute was set.
template <typename AttrType>
AttrType& Get(const std::string& attr_name) const {
  auto found = attrs_.find(attr_name);
  PADDLE_ENFORCE(found != attrs_.end(), "%s attr not registered for pass.",
                 attr_name);
  AttrType* value = boost::any_cast<AttrType*>(found->second);
  return *value;
}
// Stores `attr` under `attr_name` and takes ownership of it: a deleter is
// recorded alongside the value and is run by the destructor. Setting the same
// name twice fails a PADDLE_ENFORCE check.
template <typename AttrType>
void Set(const std::string& attr_name, AttrType* attr) {
  const bool already_present = attrs_.find(attr_name) != attrs_.end();
  PADDLE_ENFORCE(!already_present, "%s already set in the pass", attr_name);

  attrs_[attr_name] = attr;

  // The deleter captures the name by value so it can log it at teardown.
  std::function<void(void)> deleter = [attr, attr_name]() {
    VLOG(3) << "deleting " << attr_name;
    delete attr;
  };
  attr_dels_[attr_name] = deleter;
}
// Stores `attr` under `attr_name` without taking ownership: no deleter is
// recorded, so the caller must keep the object alive while the pass uses it
// and delete it afterwards. Setting the same name twice fails a
// PADDLE_ENFORCE check (with the same message Set() uses).
template <typename AttrType>
void SetNotOwned(const std::string& attr_name, AttrType* attr) {
  PADDLE_ENFORCE(attrs_.count(attr_name) == 0, "%s already set in the pass",
                 attr_name);
  attrs_[attr_name] = attr;
}
protected:
virtual
std
::
unique_ptr
<
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
Graph
>
graph
)
const
=
0
;
private:
template
<
typename
PassType
>
friend
struct
PassRegistrar
;
// Adds every name in `attrs` to the set of pass attributes that Apply()
// requires to be set.
void RegisterRequiredPassAttrs(const std::unordered_set<std::string>& attrs) {
  for (const std::string& name : attrs) {
    required_pass_attrs_.insert(name);
  }
}
// Adds every name in `attrs` to the set of graph attributes that Apply()
// requires the input graph to carry.
void RegisterRequiredGraphAttrs(const std::unordered_set<std::string>& attrs) {
  for (const std::string& name : attrs) {
    required_graph_attrs_.insert(name);
  }
}
mutable
bool
applied_
{
false
};
std
::
unordered_set
<
std
::
string
>
required_pass_attrs_
;
std
::
unordered_set
<
std
::
string
>
required_graph_attrs_
;
std
::
map
<
std
::
string
,
boost
::
any
>
attrs_
;
std
::
map
<
std
::
string
,
std
::
function
<
void
(
void
)
>>
attr_dels_
;
};
using
PassCreator
=
std
::
function
<
std
::
unique_ptr
<
Pass
>
()
>
;
// Base class of all registrar objects.
class
Registrar
{
public:
// Each kind of pass has its own registry and registrar. Registration happens
// in the constructor of a global registrar variable. Code that merely links
// against the framework never references that variable, so the linker may
// drop it from the final binary together with the registration it performs.
// Touch() is an intentionally empty method that the USE_PASS macro calls;
// as long as client code invokes USE_PASS, the registrar is referenced and
// is therefore kept by the linker.
void
Touch
()
{}
};
virtual
std
::
unique_ptr
<
Graph
>
Apply
(
std
::
unique_ptr
<
Graph
>
graph
)
const
=
0
;
class
PassRegistry
{
public:
static
PassRegistry
&
Instance
();
bool
Has
(
const
std
::
string
&
pass_type
)
const
{
return
map_
.
find
(
pass_type
)
!=
map_
.
end
();
}
void
Insert
(
const
std
::
string
&
pass_type
,
const
PassCreator
&
pass_creator
)
{
PADDLE_ENFORCE
(
!
Has
(
pass_type
),
"Pass %s has been registered"
,
pass_type
);
map_
.
insert
({
pass_type
,
pass_creator
});
}
std
::
unique_ptr
<
Pass
>
Get
(
const
std
::
string
&
pass_type
)
const
{
PADDLE_ENFORCE
(
Has
(
pass_type
),
"Pass %s has not been registered"
,
pass_type
);
return
map_
.
at
(
pass_type
)();
}
private:
PassRegistry
()
=
default
;
std
::
unordered_map
<
std
::
string
,
PassCreator
>
map_
;
DISABLE_COPY_AND_ASSIGN
(
PassRegistry
);
};
};
template
<
typename
PassType
>
struct
PassRegistrar
:
public
Registrar
{
explicit
PassRegistrar
(
const
char
*
pass_type
)
{
PADDLE_ENFORCE
(
!
PassRegistry
::
Instance
().
Has
(
pass_type
),
"'%s' is registered more than once."
,
pass_type
);
PassRegistry
::
Instance
().
Insert
(
pass_type
,
[
this
]()
->
std
::
unique_ptr
<
Pass
>
{
std
::
unique_ptr
<
Pass
>
pass
(
new
PassType
());
pass
->
RegisterRequiredPassAttrs
(
this
->
required_pass_attrs_
);
pass
->
RegisterRequiredGraphAttrs
(
this
->
required_graph_attrs_
);
return
pass
;
});
}
PassRegistrar
<
PassType
>
&
RequirePassAttr
(
const
std
::
string
&
attr
)
{
required_pass_attrs_
.
insert
(
attr
);
return
*
this
;
}
PassRegistrar
<
PassType
>
&
RequireGraphAttr
(
const
std
::
string
&
attr
)
{
required_graph_attrs_
.
insert
(
attr
);
return
*
this
;
}
private:
std
::
unordered_set
<
std
::
string
>
required_pass_attrs_
;
std
::
unordered_set
<
std
::
string
>
required_graph_attrs_
;
};
// Compile-time check that the macro is expanded in the global namespace.
// It declares an empty struct in the current scope and asserts that it is
// the same type as the identically named struct looked up with a leading
// `::`; that only holds when the current scope is the global namespace.
// `msg` is the static_assert diagnostic shown otherwise.
#define STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(uniq_name, msg) \
struct __test_global_namespace_##uniq_name##__ {}; \
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
__test_global_namespace_##uniq_name##__>::value, \
msg)
// Register a new pass that can be applied on the IR.
//
// Must be used in the global namespace. Expands to:
//   * a static PassRegistrar<pass_class> constructed with the stringified
//     `pass_type` as the registered name,
//   * a function TouchPassRegistrar_<pass_type>() that calls Touch() on that
//     registrar (USE_PASS declares and calls this function),
//   * a static reference bound to the registrar. The expansion ends on the
//     registrar expression itself, so the macro call can be followed by
//     chained `.RequirePassAttr(...)` / `.RequireGraphAttr(...)` calls before
//     the terminating semicolon.
#define REGISTER_PASS(pass_type, pass_class) \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
__reg_pass__##pass_type, \
"REGISTER_PASS must be called in global namespace"); \
static ::paddle::framework::ir::PassRegistrar<pass_class> \
__pass_registrar_##pass_type##__(#pass_type); \
int TouchPassRegistrar_##pass_type() { \
__pass_registrar_##pass_type##__.Touch(); \
return 0; \
} \
static ::paddle::framework::ir::PassRegistrar<pass_class> \
&__pass_tmp_registrar_##pass_type##__ __attribute__((unused)) = \
__pass_registrar_##pass_type##__
// References a pass registered elsewhere with REGISTER_PASS(pass_type, ...).
// Must be used in the global namespace. It declares the
// TouchPassRegistrar_<pass_type>() function generated by REGISTER_PASS and
// calls it from the initializer of a static int, which creates a reference
// to the registrar's translation unit.
#define USE_PASS(pass_type) \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
__use_pass_itself_##pass_type, \
"USE_PASS must be called in global namespace"); \
extern int TouchPassRegistrar_##pass_type(); \
static int use_pass_itself_##pass_type##_ __attribute__((unused)) = \
TouchPassRegistrar_##pass_type()
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/ir/pass_test.cc
0 → 100644
浏览文件 @
a5c96af3
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/pass.h"
#include <string>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
// Populates `g` with four nodes wired into a single cycle:
//   op1 -> var1 -> op2 -> var2 -> op1
// Used by the test below to provoke the "shouldn't has cycle" failure.
void BuildCircleGraph(Graph* g) {
  // Creation order is kept as op1, op2, var1, var2.
  ir::Node* op_a = g->CreateEmptyNode("op1", Node::Type::kOperation);
  ir::Node* op_b = g->CreateEmptyNode("op2", Node::Type::kOperation);
  ir::Node* var_a = g->CreateEmptyNode("var1", Node::Type::kVariable);
  ir::Node* var_b = g->CreateEmptyNode("var2", Node::Type::kVariable);

  // Records a directed edge on both endpoints.
  auto connect = [](ir::Node* from, ir::Node* to) {
    from->outputs.push_back(to);
    to->inputs.push_back(from);
  };

  connect(op_a, var_a);
  connect(var_a, op_b);
  connect(op_b, var_b);
  connect(var_b, op_a);  // closes the cycle
}
class
TestPass
:
public
Pass
{
protected:
std
::
unique_ptr
<
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
Graph
>
graph
)
const
{
graph
->
Set
<
int
>
(
"copy_test_pass_attr"
,
new
int
);
graph
->
Set
<
int
>
(
"copy_test_graph_attr"
,
new
int
);
int
test_pass_attr
=
this
->
Get
<
int
>
(
"test_pass_attr"
);
graph
->
Get
<
int
>
(
"copy_test_pass_attr"
)
=
test_pass_attr
+
1
;
int
test_graph_attr
=
graph
->
Get
<
int
>
(
"test_graph_attr"
);
graph
->
Get
<
int
>
(
"copy_test_graph_attr"
)
=
test_graph_attr
+
1
;
return
graph
;
}
};
// Exercises the attribute checks performed when a registered pass is applied:
//   1. applying without the required pass attribute fails;
//   2. applying without the required graph attribute fails;
//   3. with both attributes set, TestPass copies (value + 1) onto the graph;
//   4. applying the same pass instance a second time fails;
//   5. applying a fresh pass instance to a cyclic graph fails.
// Each failure is expected to surface as platform::EnforceNotMet whose
// message contains the asserted substring.
TEST(PassTest, TestPassAttrCheck) {
  ProgramDesc prog;
  auto pass = PassRegistry::Instance().Get("test_pass");
  std::unique_ptr<Graph> graph(new Graph(prog));
  std::string exception;

  // 1. Required pass attribute is missing.
  try {
    graph = pass->Apply(std::move(graph));
  } catch (const paddle::platform::EnforceNotMet& e) {
    // Caught by const reference to avoid copying/slicing the exception.
    exception = std::string(e.what());
  }
  ASSERT_TRUE(exception.find("test_pass_attr not set") != exception.npos);

  // 2. Pass attribute provided, required graph attribute still missing.
  int val = 1;
  graph.reset(new Graph(prog));
  pass->SetNotOwned<int>("test_pass_attr", &val);
  // Drop the previous message so a missing throw cannot be masked by it.
  exception.clear();
  try {
    graph = pass->Apply(std::move(graph));
  } catch (const paddle::platform::EnforceNotMet& e) {
    exception = std::string(e.what());
  }
  ASSERT_TRUE(exception.find("test_graph_attr not set") != exception.npos);

  // 3. Both attributes present: the pass must succeed and copy value + 1.
  graph.reset(new Graph(prog));
  graph->Set<int>("test_graph_attr", new int);
  graph->Get<int>("test_graph_attr") = 1;
  graph = pass->Apply(std::move(graph));
  ASSERT_EQ(graph->Get<int>("copy_test_pass_attr"), 2);
  ASSERT_EQ(graph->Get<int>("copy_test_graph_attr"), 2);

  // 4. The same pass instance must refuse a second Apply().
  exception.clear();
  try {
    graph = pass->Apply(std::move(graph));
  } catch (const paddle::platform::EnforceNotMet& e) {
    exception = std::string(e.what());
  }
  ASSERT_TRUE(exception.find("Pass can only Apply() once") != exception.npos);

  // 5. A fresh pass instance applied to a graph that contains a cycle.
  pass = PassRegistry::Instance().Get("test_pass");
  pass->SetNotOwned<int>("test_pass_attr", &val);
  graph.reset(new Graph(prog));
  BuildCircleGraph(graph.get());
  graph->Set<int>("test_graph_attr", new int);
  graph->Get<int>("test_graph_attr") = 2;
  exception.clear();
  try {
    auto tmp = pass->Apply(std::move(graph));
  } catch (const paddle::platform::EnforceNotMet& e) {
    exception = std::string(e.what());
  }
  ASSERT_TRUE(exception.find("shouldn't has cycle") != exception.npos);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
// Registers TestPass under the name "test_pass" and declares the attributes
// that must be present when it is applied: the pass attribute
// "test_pass_attr" and the graph attribute "test_graph_attr".
REGISTER_PASS
(
test_pass
,
paddle
::
framework
::
ir
::
TestPass
)
.
RequirePassAttr
(
"test_pass_attr"
)
.
RequireGraphAttr
(
"test_graph_attr"
);
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
a5c96af3
...
@@ -19,19 +19,80 @@ limitations under the License. */
...
@@ -19,19 +19,80 @@ limitations under the License. */
#include <vector>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
#endif
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/ssa_graph_builder_factory.h"
#include "paddle/fluid/framework/details/ssa_graph_checker.h"
#include "paddle/fluid/framework/details/ssa_graph_printer.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
// Builds an ir::Graph from `main_program` and runs the ParallelExecutor pass
// pipeline over it:
//   graph_viz_pass (only when strategy.debug_graphviz_path_ is non-empty)
//   -> multi_device_pass
//   -> multi_device_print_pass (only when the debug path is non-empty)
//   -> multi_device_check_pass
// The containers and the strategy are handed to multi_device_pass as
// not-owned attributes. In CUDA builds `nccl_ctxs` is forwarded only when
// `use_cuda` is true; otherwise nullptr is passed.
// Returns the graph produced by the last pass.
std::unique_ptr<ir::Graph> ApplyParallelExecutorPass(
    const ProgramDesc &main_program,
    const std::vector<platform::Place> &places,
    const std::string &loss_var_name,
    const std::unordered_set<std::string> &param_names,
    const std::vector<Scope *> &local_scopes, const bool use_cuda,
#ifdef PADDLE_WITH_CUDA
    const BuildStrategy &strategy, platform::NCCLContextMap *nccl_ctxs) {
#else
    const BuildStrategy &strategy) {
#endif
  // Graph dumps are produced only when a debug path was configured.
  const bool dump_graph = !strategy.debug_graphviz_path_.empty();

  // Convert the program to graph.
  std::unique_ptr<ir::Graph> result(new ir::Graph(main_program));

  // Record the untouched graph.
  if (dump_graph) {
    auto viz = ir::PassRegistry::Instance().Get("graph_viz_pass");
    const std::string dump_path = string::Sprintf(
        "%s%s", strategy.debug_graphviz_path_.c_str(), "_original_graph");
    viz->Set<std::string>("graph_viz_path", new std::string(dump_path));
    result = viz->Apply(std::move(result));
  }

  // Convert graph to run on multi-devices.
  auto converter = ir::PassRegistry::Instance().Get("multi_device_pass");
  converter->SetNotOwned<const std::vector<platform::Place>>("places",
                                                             &places);
  converter->SetNotOwned<const std::string>("loss_var_name", &loss_var_name);
  converter->SetNotOwned<const std::unordered_set<std::string>>("params",
                                                                &param_names);
  converter->SetNotOwned<const std::vector<Scope *>>("local_scopes",
                                                     &local_scopes);
  converter->SetNotOwned<const BuildStrategy>("strategy", &strategy);
#ifdef PADDLE_WITH_CUDA
  platform::NCCLContextMap *nccl_for_pass = nullptr;
  if (use_cuda) {
    nccl_for_pass = nccl_ctxs;
  }
  converter->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nccl_for_pass);
#endif
  result = converter->Apply(std::move(result));

  // Record the graph again, now with device info.
  if (dump_graph) {
    auto printer = ir::PassRegistry::Instance().Get("multi_device_print_pass");
    printer->SetNotOwned<const std::string>("debug_graphviz_path",
                                            &strategy.debug_graphviz_path_);
    printer->Set<details::GraphvizSSAGraphPrinter>(
        "graph_printer", new details::GraphvizSSAGraphPrinter);
    result = printer->Apply(std::move(result));
  }

  // Verify that the graph is correct for multi-device executor.
  auto checker = ir::PassRegistry::Instance().Get("multi_device_check_pass");
  result = checker->Apply(std::move(result));
  return result;
}
class
ParallelExecutorPrivate
{
class
ParallelExecutorPrivate
{
public:
public:
explicit
ParallelExecutorPrivate
(
const
std
::
vector
<
platform
::
Place
>
&
places
)
explicit
ParallelExecutorPrivate
(
const
std
::
vector
<
platform
::
Place
>
&
places
)
...
@@ -119,21 +180,19 @@ ParallelExecutor::ParallelExecutor(
...
@@ -119,21 +180,19 @@ ParallelExecutor::ParallelExecutor(
var_infos
.
back
().
persistable_
=
var
->
Persistable
();
var_infos
.
back
().
persistable_
=
var
->
Persistable
();
}
}
// Step 3. Convert main_program to SSA form and dependency graph. Also, insert
// Step 3. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp
// ncclOp
details
::
SSAGraphBuilderFactory
builder_factory
(
member_
->
places_
,
loss_var_name
,
params
,
member_
->
local_scopes_
,
build_strategy
);
if
(
member_
->
use_cuda_
)
{
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
builder_factory
.
SetNCCLContextMap
(
member_
->
nccl_ctxs_
.
get
());
std
::
unique_ptr
<
ir
::
Graph
>
graph
=
ApplyParallelExecutorPass
(
main_program
,
member_
->
places_
,
loss_var_name
,
params
,
member_
->
local_scopes_
,
member_
->
use_cuda_
,
build_strategy
,
member_
->
nccl_ctxs_
.
get
());
#else
#else
PADDLE_THROW
(
"Not compiled with CUDA."
);
std
::
unique_ptr
<
ir
::
Graph
>
graph
=
ApplyParallelExecutorPass
(
main_program
,
member_
->
places_
,
loss_var_name
,
params
,
member_
->
local_scopes_
,
member_
->
use_cuda_
,
build_strategy
);
#endif
#endif
}
builder_
=
builder_factory
.
Create
();
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
main_program
));
graph
=
builder_
->
Apply
(
std
::
move
(
graph
));
member_
->
executor_
.
reset
(
new
details
::
ThreadedSSAGraphExecutor
(
member_
->
executor_
.
reset
(
new
details
::
ThreadedSSAGraphExecutor
(
exec_strategy
,
member_
->
local_scopes_
,
places
,
std
::
move
(
graph
)));
exec_strategy
,
member_
->
local_scopes_
,
places
,
std
::
move
(
graph
)));
member_
->
executor_
.
reset
(
new
details
::
ScopeBufferedSSAGraphExecutor
(
member_
->
executor_
.
reset
(
new
details
::
ScopeBufferedSSAGraphExecutor
(
...
@@ -146,11 +205,18 @@ void ParallelExecutor::BCastParamsToDevices(
...
@@ -146,11 +205,18 @@ void ParallelExecutor::BCastParamsToDevices(
// the initializing bcast, all vars would be bcast from device(0),
// the initializing bcast, all vars would be bcast from device(0),
// otherwise
// otherwise
// bcast from the specified device.
// bcast from the specified device.
bool
initializing
=
builder_
.
get
()
==
nullptr
?
true
:
false
;
bool
initializing
=
member_
->
executor_
?
false
:
true
;
for
(
auto
&
var
:
vars
)
{
for
(
auto
&
var
:
vars
)
{
int
var_dev_id
=
int
var_dev_id
=
-
1
;
builder_
.
get
()
==
nullptr
?
-
1
:
builder_
->
GetVarDeviceID
(
var
);
if
(
member_
->
executor_
)
{
auto
&
sharded_var_device
=
member_
->
executor_
->
Graph
().
Get
<
details
::
ShardedVarDevice
>
(
details
::
kShardedVarDevice
);
if
(
sharded_var_device
.
find
(
var
)
!=
sharded_var_device
.
end
())
{
var_dev_id
=
sharded_var_device
.
at
(
var
);
}
}
if
(
!
initializing
&&
var_dev_id
==
-
1
)
continue
;
if
(
!
initializing
&&
var_dev_id
==
-
1
)
continue
;
framework
::
Variable
*
main_var
=
nullptr
;
framework
::
Variable
*
main_var
=
nullptr
;
...
@@ -286,3 +352,8 @@ ParallelExecutor::~ParallelExecutor() {
...
@@ -286,3 +352,8 @@ ParallelExecutor::~ParallelExecutor() {
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
USE_PASS
(
graph_viz_pass
);
USE_PASS
(
multi_device_pass
);
USE_PASS
(
multi_device_check_pass
);
USE_PASS
(
multi_device_print_pass
);
paddle/fluid/framework/parallel_executor.h
浏览文件 @
a5c96af3
...
@@ -70,7 +70,6 @@ class ParallelExecutor {
...
@@ -70,7 +70,6 @@ class ParallelExecutor {
private:
private:
ParallelExecutorPrivate
*
member_
;
ParallelExecutorPrivate
*
member_
;
std
::
unique_ptr
<
details
::
SSAGraphBuilder
>
builder_
;
};
};
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/inference/analysis/CMakeLists.txt
浏览文件 @
a5c96af3
...
@@ -6,9 +6,11 @@ cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph
...
@@ -6,9 +6,11 @@ cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph
tensorrt_subgraph_node_mark_pass.cc
tensorrt_subgraph_node_mark_pass.cc
analyzer.cc
analyzer.cc
helper.cc
helper.cc
model_store_pass.cc
DEPS framework_proto proto_desc
)
DEPS framework_proto proto_desc
)
cc_test
(
test_node SRCS node_tester.cc DEPS analysis
)
cc_test
(
test_node SRCS node_tester.cc DEPS analysis
)
cc_test
(
test_dot SRCS dot_tester.cc DEPS analysis
)
cc_test
(
test_dot SRCS dot_tester.cc DEPS analysis
)
cc_binary
(
inference_analyzer SRCS analyzer_main.cc DEPS analysis
)
set
(
PYTHON_TESTS_DIR
${
PADDLE_BINARY_DIR
}
/python/paddle/fluid/tests
)
set
(
PYTHON_TESTS_DIR
${
PADDLE_BINARY_DIR
}
/python/paddle/fluid/tests
)
...
@@ -40,3 +42,4 @@ inference_analysis_test(test_tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass_
...
@@ -40,3 +42,4 @@ inference_analysis_test(test_tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass_
inference_analysis_test
(
test_pass_manager SRCS pass_manager_tester.cc
)
inference_analysis_test
(
test_pass_manager SRCS pass_manager_tester.cc
)
inference_analysis_test
(
test_tensorrt_subgraph_node_mark_pass SRCS tensorrt_subgraph_node_mark_pass_tester.cc
)
inference_analysis_test
(
test_tensorrt_subgraph_node_mark_pass SRCS tensorrt_subgraph_node_mark_pass_tester.cc
)
inference_analysis_test
(
test_analyzer SRCS analyzer_tester.cc
)
inference_analysis_test
(
test_analyzer SRCS analyzer_tester.cc
)
inference_analysis_test
(
test_model_store_pass SRCS model_store_pass_tester.cc
)
paddle/fluid/inference/analysis/analyzer.cc
浏览文件 @
a5c96af3
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h"
#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h"
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
#include "paddle/fluid/inference/analysis/model_store_pass.h"
#include "paddle/fluid/inference/analysis/pass_manager.h"
#include "paddle/fluid/inference/analysis/pass_manager.h"
#include "paddle/fluid/inference/analysis/tensorrt_subgraph_node_mark_pass.h"
#include "paddle/fluid/inference/analysis/tensorrt_subgraph_node_mark_pass.h"
#include "paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h"
#include "paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h"
...
@@ -29,6 +30,9 @@ DEFINE_bool(inference_analysis_enable_tensorrt_subgraph_engine, false,
...
@@ -29,6 +30,9 @@ DEFINE_bool(inference_analysis_enable_tensorrt_subgraph_engine, false,
DEFINE_string
(
inference_analysis_graphviz_log_root
,
"./"
,
DEFINE_string
(
inference_analysis_graphviz_log_root
,
"./"
,
"Graphviz debuger for data flow graphs."
);
"Graphviz debuger for data flow graphs."
);
DEFINE_string
(
inference_analysis_output_storage_path
,
""
,
"optimized model output path"
);
namespace
inference
{
namespace
inference
{
namespace
analysis
{
namespace
analysis
{
...
@@ -47,6 +51,9 @@ class DfgPassManagerImpl final : public DfgPassManager {
...
@@ -47,6 +51,9 @@ class DfgPassManagerImpl final : public DfgPassManager {
AddPass
(
"tensorrt-subgraph"
,
new
TensorRTSubGraphPass
(
trt_teller
));
AddPass
(
"tensorrt-subgraph"
,
new
TensorRTSubGraphPass
(
trt_teller
));
}
}
AddPass
(
"data-flow-graph-to-fluid"
,
new
DataFlowGraphToFluidPass
);
AddPass
(
"data-flow-graph-to-fluid"
,
new
DataFlowGraphToFluidPass
);
if
(
!
FLAGS_inference_analysis_output_storage_path
.
empty
())
{
AddPass
(
"model-store-pass"
,
new
ModelStorePass
);
}
}
}
std
::
string
repr
()
const
override
{
return
"dfg-pass-manager"
;
}
std
::
string
repr
()
const
override
{
return
"dfg-pass-manager"
;
}
...
...
paddle/fluid/inference/analysis/analyzer.h
浏览文件 @
a5c96af3
...
@@ -16,28 +16,23 @@ limitations under the License. */
...
@@ -16,28 +16,23 @@ limitations under the License. */
/*
/*
* This file contains Analyzer, an class that exposed as a library that analyze
* This file contains Analyzer, an class that exposed as a library that analyze
* and optimize
* and optimize Fluid ProgramDesc for inference. Similar to LLVM, it has
* Fluid ProgramDesc for inference. Similar to LLVM, it has multiple flags to
* multiple flags to
* control whether
* control whether an process is applied on the program.
* an process is applied on the program.
*
*
* The processes are called Passes in analysis, the Passes are placed in a
* The processes are called Passes in analysis, the Passes are placed in a
* pipeline, the first
* pipeline, the first Pass is the FluidToDataFlowGraphPass which transforms a
* Pass is the FluidToDataFlowGraphPass which transforms a Fluid ProgramDesc to
* Fluid ProgramDesc to
* a data flow
* a data flow graph, the last Pass is DataFlowGraphToFluidPass which transforms
* graph, the last Pass is DataFlowGraphToFluidPass which transforms a data flow
* a data flow graph to a Fluid ProgramDesc. The passes in the middle of the
* graph to a
* pipeline can be any Passes
* Fluid ProgramDesc. The passes in the middle of the pipeline can be any Passes
* which take a node or data flow graph as input.
* which take a
* node or data flow graph as input.
*
*
* The Analyzer can be used in two methods, the first is a executable file which
* The Analyzer can be used in two methods, the first is a executable file which
* can be used to
* can be used to pre-process the inference model and can be controlled by
* pre-process the inference model and can be controlled by passing difference
* passing difference command flags;
* command flags;
* the other way is to compose inside the inference API as a runtime pre-process
* the other way is to compose inside the inference API as a runtime pre-process
* phase in the
* phase in the inference service.
* inference service.
*/
*/
#include <gflags/gflags.h>
#include <gflags/gflags.h>
...
@@ -50,6 +45,7 @@ namespace paddle {
...
@@ -50,6 +45,7 @@ namespace paddle {
// flag if not available.
// flag if not available.
DECLARE_bool
(
inference_analysis_enable_tensorrt_subgraph_engine
);
DECLARE_bool
(
inference_analysis_enable_tensorrt_subgraph_engine
);
DECLARE_string
(
inference_analysis_graphviz_log_root
);
DECLARE_string
(
inference_analysis_graphviz_log_root
);
DECLARE_string
(
inference_analysis_output_storage_path
);
namespace
inference
{
namespace
inference
{
namespace
analysis
{
namespace
analysis
{
...
...
paddle/fluid/inference/analysis/analyzer_main.cc
0 → 100644
浏览文件 @
a5c96af3
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* This file implements the analyzer -- an executable that helps to analyze and
* optimize a trained model.
*/
#include "paddle/fluid/inference/analysis/analyzer.h"
#include <gflags/gflags.h>
#include <glog/logging.h>
int
main
(
int
argc
,
char
**
argv
)
{
google
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
true
);
using
paddle
::
inference
::
analysis
::
Analyzer
;
using
paddle
::
inference
::
analysis
::
Argument
;
Argument
argument
;
Analyzer
analyzer
;
analyzer
.
Run
(
&
argument
);
return
0
;
}
paddle/fluid/inference/analysis/analyzer_tester.cc
浏览文件 @
a5c96af3
...
@@ -20,14 +20,18 @@ namespace paddle {
...
@@ -20,14 +20,18 @@ namespace paddle {
namespace
inference
{
namespace
inference
{
namespace
analysis
{
namespace
analysis
{
TEST
_F
(
DFG_Test
er
,
analysis_without_tensorrt
)
{
TEST
(
Analyz
er
,
analysis_without_tensorrt
)
{
FLAGS_inference_analysis_enable_tensorrt_subgraph_engine
=
false
;
FLAGS_inference_analysis_enable_tensorrt_subgraph_engine
=
false
;
Argument
argument
;
argument
.
fluid_model_dir
.
reset
(
new
std
::
string
(
FLAGS_inference_model_dir
));
Analyzer
analyser
;
Analyzer
analyser
;
analyser
.
Run
(
&
argument
);
analyser
.
Run
(
&
argument
);
}
}
TEST
_F
(
DFG_Test
er
,
analysis_with_tensorrt
)
{
TEST
(
Analyz
er
,
analysis_with_tensorrt
)
{
FLAGS_inference_analysis_enable_tensorrt_subgraph_engine
=
true
;
FLAGS_inference_analysis_enable_tensorrt_subgraph_engine
=
true
;
Argument
argument
;
argument
.
fluid_model_dir
.
reset
(
new
std
::
string
(
FLAGS_inference_model_dir
));
Analyzer
analyser
;
Analyzer
analyser
;
analyser
.
Run
(
&
argument
);
analyser
.
Run
(
&
argument
);
}
}
...
...
paddle/fluid/inference/analysis/argument.h
浏览文件 @
a5c96af3
...
@@ -36,6 +36,16 @@ namespace analysis {
...
@@ -36,6 +36,16 @@ namespace analysis {
* All the fields should be registered here for clearness.
* All the fields should be registered here for clearness.
*/
*/
struct
Argument
{
struct
Argument
{
Argument
()
=
default
;
explicit
Argument
(
const
std
::
string
&
fluid_model_dir
)
:
fluid_model_dir
(
new
std
::
string
(
fluid_model_dir
))
{}
// The directory of the trained model.
std
::
unique_ptr
<
std
::
string
>
fluid_model_dir
;
// The path of `__model__` and `param`, this is used when the file name of
// model and param is changed.
std
::
unique_ptr
<
std
::
string
>
fluid_model_program_path
;
std
::
unique_ptr
<
std
::
string
>
fluid_model_param_path
;
// The graph that process by the Passes or PassManagers.
// The graph that process by the Passes or PassManagers.
std
::
unique_ptr
<
DataFlowGraph
>
main_dfg
;
std
::
unique_ptr
<
DataFlowGraph
>
main_dfg
;
...
@@ -44,6 +54,9 @@ struct Argument {
...
@@ -44,6 +54,9 @@ struct Argument {
// The processed program desc.
// The processed program desc.
std
::
unique_ptr
<
framework
::
proto
::
ProgramDesc
>
transformed_program_desc
;
std
::
unique_ptr
<
framework
::
proto
::
ProgramDesc
>
transformed_program_desc
;
// The output storage path of ModelStorePass.
std
::
unique_ptr
<
std
::
string
>
model_output_store_path
;
};
};
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
...
...
paddle/fluid/inference/analysis/data_flow_graph.h
浏览文件 @
a5c96af3
...
@@ -36,6 +36,8 @@ namespace analysis {
...
@@ -36,6 +36,8 @@ namespace analysis {
/*
/*
* DataFlowGraph - A container of Value and Function Nodes.
* DataFlowGraph - A container of Value and Function Nodes.
*
* This is the base graph for any other type of graphs, such as SSA or CFG.
*/
*/
struct
DataFlowGraph
{
struct
DataFlowGraph
{
NodeMap
nodes
;
NodeMap
nodes
;
...
...
paddle/fluid/inference/analysis/data_flow_graph_tester.cc
浏览文件 @
a5c96af3
...
@@ -20,7 +20,7 @@ namespace inference {
...
@@ -20,7 +20,7 @@ namespace inference {
namespace
analysis
{
namespace
analysis
{
TEST
(
DataFlowGraph
,
BFS
)
{
TEST
(
DataFlowGraph
,
BFS
)
{
auto
desc
=
LoadProgramDesc
();
auto
desc
=
LoadProgramDesc
(
FLAGS_inference_model_dir
+
"/__model__"
);
auto
dfg
=
ProgramDescToDFG
(
desc
);
auto
dfg
=
ProgramDescToDFG
(
desc
);
dfg
.
Build
();
dfg
.
Build
();
...
@@ -44,7 +44,7 @@ TEST(DataFlowGraph, BFS) {
...
@@ -44,7 +44,7 @@ TEST(DataFlowGraph, BFS) {
}
}
TEST
(
DataFlowGraph
,
DFS
)
{
TEST
(
DataFlowGraph
,
DFS
)
{
auto
desc
=
LoadProgramDesc
();
auto
desc
=
LoadProgramDesc
(
FLAGS_inference_model_dir
+
"/__model__"
);
auto
dfg
=
ProgramDescToDFG
(
desc
);
auto
dfg
=
ProgramDescToDFG
(
desc
);
dfg
.
Build
();
dfg
.
Build
();
GraphTraits
<
DataFlowGraph
>
trait
(
&
dfg
);
GraphTraits
<
DataFlowGraph
>
trait
(
&
dfg
);
...
...
paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass_tester.cc
浏览文件 @
a5c96af3
...
@@ -26,21 +26,21 @@ namespace paddle {
...
@@ -26,21 +26,21 @@ namespace paddle {
namespace
inference
{
namespace
inference
{
namespace
analysis
{
namespace
analysis
{
TEST
_F
(
DFG_Tester
,
Test
)
{
TEST
(
DataFlowGraph
,
Test
)
{
DataFlowGraph
graph
;
Argument
argument
(
FLAGS_inference_model_dir
)
;
FluidToDataFlowGraphPass
pass0
;
FluidToDataFlowGraphPass
pass0
;
DataFlowGraphToFluidPass
pass1
;
DataFlowGraphToFluidPass
pass1
;
ASSERT_TRUE
(
pass0
.
Initialize
(
&
argument
));
ASSERT_TRUE
(
pass0
.
Initialize
(
&
argument
));
ASSERT_TRUE
(
pass1
.
Initialize
(
&
argument
));
ASSERT_TRUE
(
pass1
.
Initialize
(
&
argument
));
pass0
.
Run
(
&
graph
);
pass0
.
Run
(
argument
.
main_dfg
.
get
()
);
pass1
.
Run
(
&
graph
);
pass1
.
Run
(
argument
.
main_dfg
.
get
()
);
pass0
.
Finalize
();
pass0
.
Finalize
();
pass1
.
Finalize
();
pass1
.
Finalize
();
LOG
(
INFO
)
<<
graph
.
nodes
.
size
();
LOG
(
INFO
)
<<
argument
.
main_dfg
->
nodes
.
size
();
}
}
};
// namespace analysis
};
// namespace analysis
...
...
paddle/fluid/inference/analysis/dfg_graphviz_draw_pass_tester.cc
浏览文件 @
a5c96af3
...
@@ -23,12 +23,18 @@ namespace paddle {
...
@@ -23,12 +23,18 @@ namespace paddle {
namespace
inference
{
namespace
inference
{
namespace
analysis
{
namespace
analysis
{
TEST_F
(
DFG_Tester
,
dfg_graphviz_draw_pass_tester
)
{
TEST
(
DFG_GraphvizDrawPass
,
dfg_graphviz_draw_pass_tester
)
{
auto
dfg
=
ProgramDescToDFG
(
*
argument
.
origin_program_desc
);
Argument
argument
(
FLAGS_inference_model_dir
);
FluidToDataFlowGraphPass
pass0
;
ASSERT_TRUE
(
pass0
.
Initialize
(
&
argument
));
pass0
.
Run
(
argument
.
main_dfg
.
get
());
// auto dfg = ProgramDescToDFG(*argument.origin_program_desc);
DFG_GraphvizDrawPass
::
Config
config
(
"./"
,
"test"
);
DFG_GraphvizDrawPass
::
Config
config
(
"./"
,
"test"
);
DFG_GraphvizDrawPass
pass
(
config
);
DFG_GraphvizDrawPass
pass
(
config
);
pass
.
Initialize
(
&
argument
);
pass
.
Initialize
(
&
argument
);
pass
.
Run
(
&
dfg
);
pass
.
Run
(
argument
.
main_dfg
.
get
()
);
// test content
// test content
std
::
ifstream
file
(
"./0-graph_test.dot"
);
std
::
ifstream
file
(
"./0-graph_test.dot"
);
...
...
paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc
浏览文件 @
a5c96af3
...
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include <glog/logging.h>
#include <string>
#include <string>
#include <vector>
#include <vector>
...
@@ -25,8 +26,20 @@ namespace analysis {
...
@@ -25,8 +26,20 @@ namespace analysis {
bool
FluidToDataFlowGraphPass
::
Initialize
(
Argument
*
argument
)
{
bool
FluidToDataFlowGraphPass
::
Initialize
(
Argument
*
argument
)
{
ANALYSIS_ARGUMENT_CHECK_FIELD
(
argument
);
ANALYSIS_ARGUMENT_CHECK_FIELD
(
argument
);
ANALYSIS_ARGUMENT_CHECK_FIELD
(
argument
->
origin_program_desc
);
if
(
argument
->
origin_program_desc
)
{
PADDLE_ENFORCE
(
argument
);
LOG
(
WARNING
)
<<
"argument's origin_program_desc is already set, might "
"duplicate called"
;
}
if
(
!
argument
->
fluid_model_program_path
)
{
ANALYSIS_ARGUMENT_CHECK_FIELD
(
argument
->
fluid_model_dir
);
argument
->
fluid_model_program_path
.
reset
(
new
std
::
string
(
*
argument
->
fluid_model_dir
+
"/__model__"
));
}
ANALYSIS_ARGUMENT_CHECK_FIELD
(
argument
->
fluid_model_program_path
);
auto
program
=
LoadProgramDesc
(
*
argument
->
fluid_model_program_path
);
argument
->
origin_program_desc
.
reset
(
new
framework
::
proto
::
ProgramDesc
(
program
));
if
(
!
argument
->
main_dfg
)
{
if
(
!
argument
->
main_dfg
)
{
argument
->
main_dfg
.
reset
(
new
DataFlowGraph
);
argument
->
main_dfg
.
reset
(
new
DataFlowGraph
);
}
}
...
@@ -40,6 +53,8 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
...
@@ -40,6 +53,8 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
PADDLE_ENFORCE
(
graph
);
PADDLE_ENFORCE
(
graph
);
PADDLE_ENFORCE
(
desc_
);
PADDLE_ENFORCE
(
desc_
);
// insert vars
// insert vars
// The `var2id` keeps a map from a variable's name to its Node-id, the Node-id
// will keep updating to its latest alias during the graph-building.
std
::
unordered_map
<
std
::
string
,
size_t
>
var2id
;
std
::
unordered_map
<
std
::
string
,
size_t
>
var2id
;
auto
&
main_block
=
desc_
->
blocks
(
framework
::
kRootBlockIndex
);
auto
&
main_block
=
desc_
->
blocks
(
framework
::
kRootBlockIndex
);
for
(
int
i
=
0
;
i
<
main_block
.
vars_size
();
i
++
)
{
for
(
int
i
=
0
;
i
<
main_block
.
vars_size
();
i
++
)
{
...
@@ -51,6 +66,15 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
...
@@ -51,6 +66,15 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
var2id
[
var
.
name
()]
=
v
->
id
();
var2id
[
var
.
name
()]
=
v
->
id
();
}
}
// A variable in SSA form can only be written once, so if a variable is written
// multiple times (quite common in our ProgramDesc design), multiple alias
// Nodes of this variable will be created, and each will be written just once.
// A set that keeps all the variables (the originals, not the aliases)
// that have been written (as outputs). Once an Op's output variable hits the
// set, a new alias should be created and the global alias for this
// variable updated. That is what makes the Data Flow Graph an SSA graph.
std
::
unordered_set
<
Node
*>
unique_written_vars
;
for
(
int
i
=
0
;
i
<
main_block
.
ops_size
();
i
++
)
{
for
(
int
i
=
0
;
i
<
main_block
.
ops_size
();
i
++
)
{
const
auto
&
op
=
main_block
.
ops
(
i
);
const
auto
&
op
=
main_block
.
ops
(
i
);
auto
*
o
=
graph
->
nodes
.
Create
(
Node
::
Type
::
kFunction
);
auto
*
o
=
graph
->
nodes
.
Create
(
Node
::
Type
::
kFunction
);
...
@@ -62,33 +86,33 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
...
@@ -62,33 +86,33 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
o
->
SetPbMsg
(
op
.
SerializeAsString
());
o
->
SetPbMsg
(
op
.
SerializeAsString
());
// set inputs and outputs
// set inputs and outputs
std
::
unordered_set
<
Node
*>
inlinks
;
for
(
int
j
=
0
;
j
<
op
.
inputs_size
();
j
++
)
{
for
(
int
j
=
0
;
j
<
op
.
inputs_size
();
j
++
)
{
auto
&
in_var
=
op
.
inputs
(
j
);
auto
&
in_var
=
op
.
inputs
(
j
);
for
(
int
k
=
0
;
k
<
in_var
.
arguments_size
();
k
++
)
{
for
(
int
k
=
0
;
k
<
in_var
.
arguments_size
();
k
++
)
{
auto
*
in
=
graph
->
nodes
.
GetMutable
(
var2id
.
at
(
in_var
.
arguments
(
k
)));
auto
*
in
=
graph
->
nodes
.
GetMutable
(
var2id
.
at
(
in_var
.
arguments
(
k
)));
in
->
outlinks
.
push_back
(
o
);
in
->
outlinks
.
push_back
(
o
);
o
->
inlinks
.
push_back
(
in
);
o
->
inlinks
.
push_back
(
in
);
inlinks
.
insert
(
in
);
}
}
}
}
for
(
int
j
=
0
;
j
<
op
.
outputs_size
();
j
++
)
{
for
(
int
j
=
0
;
j
<
op
.
outputs_size
();
j
++
)
{
auto
&
out_var
=
op
.
outputs
(
j
);
auto
&
out_var
=
op
.
outputs
(
j
);
for
(
int
k
=
0
;
k
<
out_var
.
arguments_size
();
k
++
)
{
for
(
int
k
=
0
;
k
<
out_var
.
arguments_size
();
k
++
)
{
auto
*
out
=
graph
->
nodes
.
GetMutable
(
var2id
[
out_var
.
arguments
(
k
)]);
auto
*
out
=
graph
->
nodes
.
GetMutable
(
var2id
[
out_var
.
arguments
(
k
)]);
if
(
inlink
s
.
count
(
out
))
{
if
(
unique_written_var
s
.
count
(
out
))
{
// Loop found, for example, a = op(a), use SSA, change to a1 = op(a).
// Loop found, for example, a = op(a), use SSA, change to a1 = op(a).
auto
*
out_alias
=
graph
->
nodes
.
Create
(
Node
::
Type
::
kValue
);
auto
*
out_alias
=
graph
->
nodes
.
Create
(
Node
::
Type
::
kValue
);
out_alias
->
SetName
(
out
->
name
());
out_alias
->
SetName
(
out
->
name
());
out_alias
->
SetPbDesc
(
out
->
pb_desc
());
out_alias
->
SetPbDesc
(
out
->
pb_desc
());
out_alias
->
SetPbMsg
(
out
->
pb_msg
());
out_alias
->
SetPbMsg
(
out
->
pb_msg
());
var2id
[
out_alias
->
name
()]
=
out_alias
->
id
();
// update a -> a0
var2id
[
out_alias
->
name
()]
=
out_alias
->
id
();
// update variable's alias Node
LOG
(
INFO
)
<<
"loop found in graph, create SSA alias node ["
LOG
(
INFO
)
<<
"loop found in graph, create SSA alias node ["
<<
out_alias
->
repr
()
<<
"] for ["
<<
out
->
repr
()
<<
"]"
;
<<
out_alias
->
repr
()
<<
"] for ["
<<
out
->
repr
()
<<
"]"
;
out
=
out_alias
;
out
=
out_alias
;
}
}
out
->
inlinks
.
push_back
(
o
);
out
->
inlinks
.
push_back
(
o
);
o
->
outlinks
.
push_back
(
out
);
o
->
outlinks
.
push_back
(
out
);
unique_written_vars
.
insert
(
out
);
}
}
}
}
}
}
...
...
paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h
浏览文件 @
a5c96af3
...
@@ -30,7 +30,7 @@ namespace inference {
...
@@ -30,7 +30,7 @@ namespace inference {
namespace
analysis
{
namespace
analysis
{
/*
/*
* Transform a FluidDesc to a
data flow graph
.
* Transform a FluidDesc to a
SSA
.
*/
*/
class
FluidToDataFlowGraphPass
final
:
public
DataFlowGraphPass
{
class
FluidToDataFlowGraphPass
final
:
public
DataFlowGraphPass
{
public:
public:
...
...
paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass_tester.cc
浏览文件 @
a5c96af3
...
@@ -21,8 +21,9 @@ namespace paddle {
...
@@ -21,8 +21,9 @@ namespace paddle {
namespace
inference
{
namespace
inference
{
namespace
analysis
{
namespace
analysis
{
TEST
_F
(
DFG_Tester
,
Ini
t
)
{
TEST
(
FluidToDataFlowGraphPass
,
Tes
t
)
{
FluidToDataFlowGraphPass
pass
;
FluidToDataFlowGraphPass
pass
;
Argument
argument
(
FLAGS_inference_model_dir
);
pass
.
Initialize
(
&
argument
);
pass
.
Initialize
(
&
argument
);
pass
.
Run
(
argument
.
main_dfg
.
get
());
pass
.
Run
(
argument
.
main_dfg
.
get
());
// Analysis is sensitive to ProgramDesc, careful to change the original model.
// Analysis is sensitive to ProgramDesc, careful to change the original model.
...
...
paddle/fluid/inference/analysis/helper.h
浏览文件 @
a5c96af3
...
@@ -15,6 +15,7 @@ limitations under the License. */
...
@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#pragma once
#include <cstdio>
#include <cstdio>
#include <fstream>
#include <string>
#include <string>
#include <typeindex>
#include <typeindex>
#include <unordered_map>
#include <unordered_map>
...
@@ -136,6 +137,20 @@ static void ExecShellCommand(const std::string &cmd, std::string *message) {
...
@@ -136,6 +137,20 @@ static void ExecShellCommand(const std::string &cmd, std::string *message) {
}
}
}
}
/*
 * Load a serialized ProgramDesc from the binary file at `model_path`.
 *
 * Fails through PADDLE_ENFORCE when the file cannot be opened, its size cannot
 * be determined, it cannot be read completely, or its content does not parse
 * as a ProgramDesc.
 */
static framework::proto::ProgramDesc LoadProgramDesc(
    const std::string &model_path) {
  std::ifstream fin(model_path, std::ios::in | std::ios::binary);
  PADDLE_ENFORCE(fin.is_open(), "Cannot open file %s", model_path);

  fin.seekg(0, std::ios::end);
  // tellg() reports failure as -1; passing that straight to the string
  // constructor would be converted to an enormous unsigned length.
  const std::streamoff file_size = fin.tellg();
  PADDLE_ENFORCE(file_size >= 0, "Cannot get the size of file %s", model_path);

  std::string buffer(static_cast<size_t>(file_size), ' ');
  fin.seekg(0, std::ios::beg);
  fin.read(&buffer[0], buffer.size());
  PADDLE_ENFORCE(!fin.fail(), "Cannot read file %s", model_path);
  fin.close();

  framework::proto::ProgramDesc program_desc;
  // Do not hand back a silently empty/partial program when the bytes are not
  // a valid serialized ProgramDesc.
  PADDLE_ENFORCE(program_desc.ParseFromString(buffer),
                 "Cannot parse ProgramDesc from file %s", model_path);
  return program_desc;
}
}
// namespace analysis
}
// namespace analysis
}
// namespace inference
}
// namespace inference
}
// namespace paddle
}
// namespace paddle
...
...
paddle/fluid/inference/analysis/model_store_pass.cc
0 → 100644
浏览文件 @
a5c96af3
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/model_store_pass.h"
#include <stdio.h>
#include <stdlib.h>
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/analysis/argument.h"
namespace
paddle
{
namespace
inference
{
namespace
analysis
{
// Store the transformed model under `argument_->model_output_store_path`.
//
// The destination directory is created and the entries matched by
// `<fluid_model_dir>/*` are copied into it with shell commands, then the
// serialized `transformed_program_desc` is written there as `__model__`.
// The graph `x` is not read by this pass.
//
// Fails through PADDLE_ENFORCE* when a required field of the argument is
// missing, a shell command returns non-zero, or the program file cannot be
// opened or written.
void ModelStorePass::Run(DataFlowGraph *x) {
  PADDLE_ENFORCE_NOT_NULL(argument_,
                          "ModelStorePass should be initialized first.");
  // The model directory is dereferenced by the copy command below even when
  // the caller supplied the param path, so validate it unconditionally.
  PADDLE_ENFORCE_NOT_NULL(argument_->fluid_model_dir);
  if (!argument_->fluid_model_param_path) {
    // Join with a separator, like the other paths built from this directory.
    argument_->fluid_model_param_path.reset(
        new std::string(*argument_->fluid_model_dir + "/param"));
  }
  PADDLE_ENFORCE_NOT_NULL(argument_->model_output_store_path);
  // Check this before touching the file system so that a missing program does
  // not leave a half-populated output directory behind.
  PADDLE_ENFORCE_NOT_NULL(argument_->transformed_program_desc,
                          "program desc is not transformed, should call "
                          "DataFlowGraphToFluidPass first.");

  // Directly copy param file to destination.
  std::stringstream ss;
  // NOTE these commands only work on Linux.
  // NOTE(review): the paths are spliced into the shell command unquoted, so
  // paths containing spaces or shell metacharacters will not work as intended.
  ss << "mkdir -p " << *argument_->model_output_store_path;
  LOG(INFO) << "run command: " << ss.str();
  PADDLE_ENFORCE_EQ(system(ss.str().c_str()), 0);
  ss.str("");

  ss << "cp " << *argument_->fluid_model_dir << "/*"
     << " " << *argument_->model_output_store_path;
  LOG(INFO) << "run command: " << ss.str();
  PADDLE_ENFORCE_EQ(system(ss.str().c_str()), 0);

  // Store program
  const std::string program_output_path =
      *argument_->model_output_store_path + "/__model__";
  std::ofstream file(program_output_path, std::ios::binary);
  PADDLE_ENFORCE(file.is_open(), "failed to open %s to write.",
                 program_output_path);
  const std::string serialized_message =
      argument_->transformed_program_desc->SerializeAsString();
  file.write(serialized_message.c_str(), serialized_message.size());
  // Close explicitly so that buffered-write errors surface here instead of
  // being dropped by the destructor.
  file.close();
  PADDLE_ENFORCE(!file.fail(), "failed to write %s.", program_output_path);
}
}
// namespace analysis
}
// namespace inference
}
// namespace paddle
paddle/fluid/
framework/details/ssa_graph_builder_factory.cc
→
paddle/fluid/
inference/analysis/model_store_pass.h
浏览文件 @
a5c96af3
...
@@ -12,39 +12,40 @@
...
@@ -12,39 +12,40 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/details/ssa_graph_builder_factory.h"
/*
#include <fstream>
* This file defines ModelStorePass, which store the runtime DFG to a Paddle
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
* model in the disk, and that model can be reloaded for prediction.
#include "paddle/fluid/framework/details/ssa_graph_checker.h"
*/
#include "paddle/fluid/framework/details/ssa_graph_printer.h"
#include "paddle/fluid/inference/analysis/pass.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
inference
{
namespace
details
{
namespace
analysis
{
std
::
unique_ptr
<
SSAGraphBuilder
>
SSAGraphBuilderFactory
::
Create
()
{
std
::
unique_ptr
<
SSAGraphBuilder
>
res
(
class
ModelStorePass
:
public
DataFlowGraphPass
{
#ifdef PADDLE_WITH_CUDA
public:
new
MultiDevSSAGraphBuilder
(
places_
,
loss_var_name_
,
param_names_
,
bool
Initialize
(
Argument
*
argument
)
override
{
local_scopes_
,
nccl_ctxs_
,
strategy_
)
if
(
!
argument
)
{
#else
LOG
(
ERROR
)
<<
"invalid argument"
;
new
MultiDevSSAGraphBuilder
(
places_
,
loss_var_name_
,
param_names_
,
return
false
;
local_scopes_
,
strategy_
)
#endif
);
// NOLINT
if
(
!
strategy_
.
debug_graphviz_path_
.
empty
())
{
std
::
unique_ptr
<
std
::
ostream
>
fout
(
new
std
::
ofstream
(
strategy_
.
debug_graphviz_path_
));
PADDLE_ENFORCE
(
fout
->
good
());
std
::
unique_ptr
<
GraphvizSSAGraphPrinter
>
graphviz_printer
(
new
GraphvizSSAGraphPrinter
());
res
.
reset
(
new
SSAGraghBuilderWithPrinter
(
std
::
move
(
fout
),
std
::
move
(
graphviz_printer
),
std
::
move
(
res
)));
}
}
res
.
reset
(
new
SSAGraghBuilderWithChecker
(
std
::
move
(
res
)));
argument_
=
argument
;
return
true
;
}
void
Run
(
DataFlowGraph
*
x
)
override
;
std
::
string
repr
()
const
override
{
return
"DFG-store-pass"
;
}
std
::
string
description
()
const
override
{
return
R"DD(This file defines ModelStorePass, which store the runtime DFG to a Paddle
model in the disk, and that model can be reloaded for prediction again.)DD"
;
}
private:
Argument
*
argument_
{
nullptr
};
};
return
res
;
}
// namespace analysis
}
}
// namespace inference
}
// namespace details
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/inference/analysis/model_store_pass_tester.cc
0 → 100644
浏览文件 @
a5c96af3
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/model_store_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include "paddle/fluid/inference/analysis/analyzer.h"
namespace
paddle
{
namespace
inference
{
namespace
analysis
{
DEFINE_string
(
inference_model_dir
,
""
,
"Model path"
);
// Runs the analyzer on the model given by FLAGS_inference_model_dir, then runs
// ModelStorePass on the resulting main data flow graph with
// "./_dfg_store_pass_tmp" as the output path. There are no explicit
// assertions; the test only requires that the calls complete.
TEST(DFG_StorePass, test) {
  Analyzer analyzer;
  Argument argument(FLAGS_inference_model_dir);
  argument.model_output_store_path.reset(
      new std::string("./_dfg_store_pass_tmp"));
  // Disable storage in the analyzer (its output storage path flag is cleared
  // here); ModelStorePass is run explicitly below.
  FLAGS_inference_analysis_output_storage_path = "";
  analyzer.Run(&argument);

  ModelStorePass pass;
  pass.Initialize(&argument);
  pass.Run(argument.main_dfg.get());
}
}
// namespace analysis
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/analysis/pass.h
浏览文件 @
a5c96af3
...
@@ -50,6 +50,7 @@ class Pass {
...
@@ -50,6 +50,7 @@ class Pass {
// Create a debugger Pass that draw the DFG by graphviz toolkit.
// Create a debugger Pass that draw the DFG by graphviz toolkit.
virtual
Pass
*
CreateGraphvizDebugerPass
()
const
{
return
nullptr
;
}
virtual
Pass
*
CreateGraphvizDebugerPass
()
const
{
return
nullptr
;
}
virtual
void
Run
()
{
LOG
(
FATAL
)
<<
"not valid"
;
}
// Run on a single Node.
// Run on a single Node.
virtual
void
Run
(
Node
*
x
)
{
LOG
(
FATAL
)
<<
"not valid"
;
}
virtual
void
Run
(
Node
*
x
)
{
LOG
(
FATAL
)
<<
"not valid"
;
}
// Run on a single Function.
// Run on a single Function.
...
...
paddle/fluid/inference/analysis/pass_manager_tester.cc
浏览文件 @
a5c96af3
...
@@ -56,7 +56,7 @@ class TestNodePass final : public NodePass {
...
@@ -56,7 +56,7 @@ class TestNodePass final : public NodePass {
std
::
string
description
()
const
override
{
return
"some doc"
;
}
std
::
string
description
()
const
override
{
return
"some doc"
;
}
};
};
TEST
_F
(
DFG_Test
er
,
DFG_pass_manager
)
{
TEST
(
PassManag
er
,
DFG_pass_manager
)
{
TestDfgPassManager
manager
;
TestDfgPassManager
manager
;
DFG_GraphvizDrawPass
::
Config
config
(
"./"
,
"dfg.dot"
);
DFG_GraphvizDrawPass
::
Config
config
(
"./"
,
"dfg.dot"
);
...
@@ -64,12 +64,15 @@ TEST_F(DFG_Tester, DFG_pass_manager) {
...
@@ -64,12 +64,15 @@ TEST_F(DFG_Tester, DFG_pass_manager) {
manager
.
Register
(
"graphviz"
,
new
DFG_GraphvizDrawPass
(
config
));
manager
.
Register
(
"graphviz"
,
new
DFG_GraphvizDrawPass
(
config
));
manager
.
Register
(
"dfg-to-fluid"
,
new
DataFlowGraphToFluidPass
);
manager
.
Register
(
"dfg-to-fluid"
,
new
DataFlowGraphToFluidPass
);
Argument
argument
(
FLAGS_inference_model_dir
);
ASSERT_TRUE
(
&
argument
);
ASSERT_TRUE
(
&
argument
);
ASSERT_TRUE
(
manager
.
Initialize
(
&
argument
));
ASSERT_TRUE
(
manager
.
Initialize
(
&
argument
));
manager
.
RunAll
();
manager
.
RunAll
();
}
}
TEST_F
(
DFG_Tester
,
Node_pass_manager
)
{
TEST
(
PassManager
,
Node_pass_manager
)
{
Argument
argument
(
FLAGS_inference_model_dir
);
// Pre-process: initialize the DFG with the ProgramDesc first.
// Pre-process: initialize the DFG with the ProgramDesc first.
FluidToDataFlowGraphPass
pass0
;
FluidToDataFlowGraphPass
pass0
;
pass0
.
Initialize
(
&
argument
);
pass0
.
Initialize
(
&
argument
);
...
...
paddle/fluid/inference/analysis/subgraph_splitter_tester.cc
浏览文件 @
a5c96af3
...
@@ -31,8 +31,8 @@ SubGraphSplitter::NodeInsideSubgraphTeller teller = [](const Node* node) {
...
@@ -31,8 +31,8 @@ SubGraphSplitter::NodeInsideSubgraphTeller teller = [](const Node* node) {
return
false
;
return
false
;
};
};
TEST
_F
(
DFG_Tes
ter
,
Split
)
{
TEST
(
SubGraphSplit
ter
,
Split
)
{
auto
desc
=
LoadProgramDesc
();
auto
desc
=
LoadProgramDesc
(
FLAGS_inference_model_dir
+
"/__model__"
);
auto
dfg
=
ProgramDescToDFG
(
desc
);
auto
dfg
=
ProgramDescToDFG
(
desc
);
LOG
(
INFO
)
<<
"spliter
\n
"
<<
dfg
.
DotString
();
LOG
(
INFO
)
<<
"spliter
\n
"
<<
dfg
.
DotString
();
...
@@ -63,8 +63,8 @@ TEST_F(DFG_Tester, Split) {
...
@@ -63,8 +63,8 @@ TEST_F(DFG_Tester, Split) {
ASSERT_EQ
(
subgraphs
.
back
().
size
(),
6UL
);
ASSERT_EQ
(
subgraphs
.
back
().
size
(),
6UL
);
}
}
TEST
_F
(
DFG_Tes
ter
,
Fuse
)
{
TEST
(
SubGraphSplit
ter
,
Fuse
)
{
auto
desc
=
LoadProgramDesc
();
auto
desc
=
LoadProgramDesc
(
FLAGS_inference_model_dir
+
"/__model__"
);
auto
dfg
=
ProgramDescToDFG
(
desc
);
auto
dfg
=
ProgramDescToDFG
(
desc
);
size_t
count0
=
dfg
.
nodes
.
size
();
size_t
count0
=
dfg
.
nodes
.
size
();
...
...
paddle/fluid/inference/analysis/tensorrt_subgraph_node_mark_pass_tester.cc
浏览文件 @
a5c96af3
...
@@ -22,11 +22,11 @@ namespace paddle {
...
@@ -22,11 +22,11 @@ namespace paddle {
namespace
inference
{
namespace
inference
{
namespace
analysis
{
namespace
analysis
{
TEST
_F
(
DFG_Tester
,
tensorrt_subgraph_node_mark_pass
)
{
TEST
(
TensorRTSubgraphNodeMarkPass
,
test
)
{
// init
// init
FluidToDataFlowGraphPass
pass
;
FluidToDataFlowGraphPass
pass
;
Argument
argument
(
FLAGS_inference_model_dir
);
ASSERT_TRUE
(
pass
.
Initialize
(
&
argument
));
ASSERT_TRUE
(
pass
.
Initialize
(
&
argument
));
argument
.
main_dfg
.
reset
(
new
DataFlowGraph
);
pass
.
Run
(
argument
.
main_dfg
.
get
());
pass
.
Run
(
argument
.
main_dfg
.
get
());
TensorRTSubgraphNodeMarkPass
::
teller_t
teller
=
[](
const
Node
*
node
)
{
TensorRTSubgraphNodeMarkPass
::
teller_t
teller
=
[](
const
Node
*
node
)
{
...
@@ -41,7 +41,7 @@ TEST_F(DFG_Tester, tensorrt_subgraph_node_mark_pass) {
...
@@ -41,7 +41,7 @@ TEST_F(DFG_Tester, tensorrt_subgraph_node_mark_pass) {
for
(
auto
&
node
:
argument
.
main_dfg
->
nodes
.
nodes
())
{
for
(
auto
&
node
:
argument
.
main_dfg
->
nodes
.
nodes
())
{
counter
+=
node
->
attr
(
ATTR_supported_by_tensorrt
).
Bool
();
counter
+=
node
->
attr
(
ATTR_supported_by_tensorrt
).
Bool
();
}
}
ASSERT_EQ
(
counter
,
2
);
LOG
(
INFO
)
<<
counter
<<
" nodes marked"
;
LOG
(
INFO
)
<<
counter
<<
" nodes marked"
;
}
}
...
...
paddle/fluid/inference/analysis/tensorrt_subgraph_pass_tester.cc
浏览文件 @
a5c96af3
...
@@ -25,7 +25,7 @@ namespace analysis {
...
@@ -25,7 +25,7 @@ namespace analysis {
DEFINE_string
(
dot_dir
,
"./"
,
""
);
DEFINE_string
(
dot_dir
,
"./"
,
""
);
TEST
_F
(
DFG_Tester
,
tensorrt_single_pass
)
{
TEST
(
TensorRTSubGraphPass
,
main
)
{
std
::
unordered_set
<
std
::
string
>
teller_set
(
std
::
unordered_set
<
std
::
string
>
teller_set
(
{
"elementwise_add"
,
"mul"
,
"sigmoid"
});
{
"elementwise_add"
,
"mul"
,
"sigmoid"
});
SubGraphSplitter
::
NodeInsideSubgraphTeller
teller
=
[
&
](
const
Node
*
node
)
{
SubGraphSplitter
::
NodeInsideSubgraphTeller
teller
=
[
&
](
const
Node
*
node
)
{
...
@@ -35,7 +35,8 @@ TEST_F(DFG_Tester, tensorrt_single_pass) {
...
@@ -35,7 +35,8 @@ TEST_F(DFG_Tester, tensorrt_single_pass) {
return
false
;
return
false
;
};
};
LOG
(
INFO
)
<<
"init"
;
Argument
argument
(
FLAGS_inference_model_dir
);
DFG_GraphvizDrawPass
::
Config
config
{
FLAGS_dot_dir
,
"origin"
};
DFG_GraphvizDrawPass
::
Config
config
{
FLAGS_dot_dir
,
"origin"
};
DFG_GraphvizDrawPass
::
Config
config1
{
FLAGS_dot_dir
,
"fusion"
};
DFG_GraphvizDrawPass
::
Config
config1
{
FLAGS_dot_dir
,
"fusion"
};
...
@@ -44,13 +45,11 @@ TEST_F(DFG_Tester, tensorrt_single_pass) {
...
@@ -44,13 +45,11 @@ TEST_F(DFG_Tester, tensorrt_single_pass) {
FluidToDataFlowGraphPass
pass0
;
FluidToDataFlowGraphPass
pass0
;
TensorRTSubGraphPass
trt_pass
(
std
::
move
(
teller
));
TensorRTSubGraphPass
trt_pass
(
std
::
move
(
teller
));
LOG
(
INFO
)
<<
"Initialize"
;
dfg_pass
.
Initialize
(
&
argument
);
dfg_pass
.
Initialize
(
&
argument
);
dfg_pass1
.
Initialize
(
&
argument
);
dfg_pass1
.
Initialize
(
&
argument
);
pass0
.
Initialize
(
&
argument
);
pass0
.
Initialize
(
&
argument
);
trt_pass
.
Initialize
(
&
argument
);
trt_pass
.
Initialize
(
&
argument
);
LOG
(
INFO
)
<<
"Run"
;
argument
.
main_dfg
.
reset
(
new
DataFlowGraph
);
argument
.
main_dfg
.
reset
(
new
DataFlowGraph
);
pass0
.
Run
(
argument
.
main_dfg
.
get
());
pass0
.
Run
(
argument
.
main_dfg
.
get
());
dfg_pass
.
Run
(
argument
.
main_dfg
.
get
());
dfg_pass
.
Run
(
argument
.
main_dfg
.
get
());
...
...
paddle/fluid/inference/analysis/ut_helper.h
浏览文件 @
a5c96af3
...
@@ -20,7 +20,7 @@ limitations under the License. */
...
@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
#include "paddle/fluid/inference/analysis/
ut_
helper.h"
#include "paddle/fluid/inference/analysis/helper.h"
namespace
paddle
{
namespace
paddle
{
namespace
inference
{
namespace
inference
{
...
@@ -32,27 +32,12 @@ namespace analysis {
...
@@ -32,27 +32,12 @@ namespace analysis {
DEFINE_string
(
inference_model_dir
,
""
,
"inference test model dir"
);
DEFINE_string
(
inference_model_dir
,
""
,
"inference test model dir"
);
static
framework
::
proto
::
ProgramDesc
LoadProgramDesc
(
const
std
::
string
&
model_dir
=
FLAGS_inference_model_dir
)
{
std
::
string
msg
;
std
::
string
net_file
=
FLAGS_inference_model_dir
+
"/__model__"
;
std
::
ifstream
fin
(
net_file
,
std
::
ios
::
in
|
std
::
ios
::
binary
);
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
fin
),
"Cannot open file %s"
,
net_file
);
fin
.
seekg
(
0
,
std
::
ios
::
end
);
msg
.
resize
(
fin
.
tellg
());
fin
.
seekg
(
0
,
std
::
ios
::
beg
);
fin
.
read
(
&
(
msg
.
at
(
0
)),
msg
.
size
());
fin
.
close
();
framework
::
proto
::
ProgramDesc
program_desc
;
program_desc
.
ParseFromString
(
msg
);
return
program_desc
;
}
static
DataFlowGraph
ProgramDescToDFG
(
static
DataFlowGraph
ProgramDescToDFG
(
const
framework
::
proto
::
ProgramDesc
&
desc
)
{
const
framework
::
proto
::
ProgramDesc
&
desc
)
{
DataFlowGraph
graph
;
DataFlowGraph
graph
;
FluidToDataFlowGraphPass
pass
;
FluidToDataFlowGraphPass
pass
;
Argument
argument
;
Argument
argument
;
argument
.
fluid_model_dir
.
reset
(
new
std
::
string
(
FLAGS_inference_model_dir
));
argument
.
origin_program_desc
.
reset
(
new
framework
::
proto
::
ProgramDesc
(
desc
));
argument
.
origin_program_desc
.
reset
(
new
framework
::
proto
::
ProgramDesc
(
desc
));
pass
.
Initialize
(
&
argument
);
pass
.
Initialize
(
&
argument
);
pass
.
Run
(
&
graph
);
pass
.
Run
(
&
graph
);
...
@@ -63,7 +48,7 @@ static DataFlowGraph ProgramDescToDFG(
...
@@ -63,7 +48,7 @@ static DataFlowGraph ProgramDescToDFG(
class
DFG_Tester
:
public
::
testing
::
Test
{
class
DFG_Tester
:
public
::
testing
::
Test
{
protected:
protected:
void
SetUp
()
override
{
void
SetUp
()
override
{
auto
desc
=
LoadProgramDesc
(
FLAGS_inference_model_dir
);
auto
desc
=
LoadProgramDesc
(
FLAGS_inference_model_dir
+
"/__model__"
);
argument
.
origin_program_desc
.
reset
(
new
framework
::
proto
::
ProgramDesc
(
desc
));
argument
.
origin_program_desc
.
reset
(
new
framework
::
proto
::
ProgramDesc
(
desc
));
}
}
...
...
paddle/fluid/inference/api/api_anakin_engine_tester.cc
浏览文件 @
a5c96af3
...
@@ -37,19 +37,21 @@ TEST(inference, anakin) {
...
@@ -37,19 +37,21 @@ TEST(inference, anakin) {
float
data
[
1
*
3
*
224
*
224
]
=
{
1.0
f
};
float
data
[
1
*
3
*
224
*
224
]
=
{
1.0
f
};
PaddleTensor
tensor
{.
name
=
"input_0"
,
PaddleTensor
tensor
;
.
shape
=
std
::
vector
<
int
>
({
1
,
3
,
224
,
224
}),
tensor
.
name
=
"input_0"
;
.
data
=
PaddleBuf
(
data
,
sizeof
(
data
)),
tensor
.
shape
=
std
::
vector
<
int
>
({
1
,
3
,
224
,
224
});
.
dtype
=
PaddleDType
::
FLOAT32
};
tensor
.
data
=
PaddleBuf
(
data
,
sizeof
(
data
));
tensor
.
dtype
=
PaddleDType
::
FLOAT32
;
// For simplicity, we set all the slots with the same data.
// For simplicity, we set all the slots with the same data.
std
::
vector
<
PaddleTensor
>
paddle_tensor_feeds
;
std
::
vector
<
PaddleTensor
>
paddle_tensor_feeds
;
paddle_tensor_feeds
.
emplace_back
(
std
::
move
(
tensor
));
paddle_tensor_feeds
.
emplace_back
(
std
::
move
(
tensor
));
PaddleTensor
tensor_out
{.
name
=
"prob_out"
,
PaddleTensor
tensor_out
;
.
shape
=
std
::
vector
<
int
>
({
1000
,
1
}),
tensor_out
.
name
=
"prob_out"
;
.
data
=
PaddleBuf
(),
tensor_out
.
shape
=
std
::
vector
<
int
>
({
1000
,
1
});
.
dtype
=
PaddleDType
::
FLOAT32
};
tensor_out
.
data
=
PaddleBuf
();
tensor_out
.
dtype
=
PaddleDType
::
FLOAT32
;
std
::
vector
<
PaddleTensor
>
outputs
;
std
::
vector
<
PaddleTensor
>
outputs
;
outputs
.
emplace_back
(
std
::
move
(
tensor_out
));
outputs
.
emplace_back
(
std
::
move
(
tensor_out
));
...
...
paddle/fluid/inference/api/api_impl.cc
浏览文件 @
a5c96af3
...
@@ -183,6 +183,13 @@ bool NativePaddlePredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
...
@@ -183,6 +183,13 @@ bool NativePaddlePredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
// TODO(panyx0718): Init LoDTensor from existing memcpy to save a copy.
// TODO(panyx0718): Init LoDTensor from existing memcpy to save a copy.
std
::
memcpy
(
static_cast
<
void
*>
(
input_ptr
),
inputs
[
i
].
data
.
data
(),
std
::
memcpy
(
static_cast
<
void
*>
(
input_ptr
),
inputs
[
i
].
data
.
data
(),
inputs
[
i
].
data
.
length
());
inputs
[
i
].
data
.
length
());
// TODO(Superjomn) Low performance, need optimization for heavy LoD copy.
framework
::
LoD
lod
;
for
(
auto
&
level
:
inputs
[
i
].
lod
)
{
lod
.
emplace_back
(
level
);
}
input
.
set_lod
(
lod
);
feeds
->
push_back
(
input
);
feeds
->
push_back
(
input
);
}
}
return
true
;
return
true
;
...
@@ -248,6 +255,10 @@ bool NativePaddlePredictor::GetFetch(
...
@@ -248,6 +255,10 @@ bool NativePaddlePredictor::GetFetch(
buffer
.
Resize
(
sizeof
(
float
)
*
data
.
size
());
buffer
.
Resize
(
sizeof
(
float
)
*
data
.
size
());
}
}
std
::
memcpy
(
buffer
.
data
(),
data
.
data
(),
buffer
.
length
());
std
::
memcpy
(
buffer
.
data
(),
data
.
data
(),
buffer
.
length
());
// copy LoD
for
(
const
auto
&
level
:
fetchs
[
i
].
lod
())
{
outputs
->
at
(
i
).
lod
.
emplace_back
(
level
);
}
outputs
->
at
(
i
).
dtype
=
PaddleDType
::
FLOAT32
;
outputs
->
at
(
i
).
dtype
=
PaddleDType
::
FLOAT32
;
// TODO(panyx0718): support other types? fill tensor name? avoid a copy.
// TODO(panyx0718): support other types? fill tensor name? avoid a copy.
}
}
...
...
paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc
浏览文件 @
a5c96af3
...
@@ -90,6 +90,18 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor {
...
@@ -90,6 +90,18 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor {
void
OptimizeInferenceProgram
()
{
void
OptimizeInferenceProgram
()
{
// Analyze inference_program
// Analyze inference_program
Argument
argument
;
Argument
argument
;
if
(
!
config_
.
model_dir
.
empty
())
{
argument
.
fluid_model_dir
.
reset
(
new
std
::
string
(
config_
.
model_dir
));
}
else
{
PADDLE_ENFORCE
(
!
config_
.
param_file
.
empty
(),
"Either model_dir or (param_file, prog_file) should be set."
);
PADDLE_ENFORCE
(
!
config_
.
prog_file
.
empty
());
argument
.
fluid_model_program_path
.
reset
(
new
std
::
string
(
config_
.
prog_file
));
argument
.
fluid_model_param_path
.
reset
(
new
std
::
string
(
config_
.
param_file
));
}
argument
.
origin_program_desc
.
reset
(
argument
.
origin_program_desc
.
reset
(
new
ProgramDesc
(
*
inference_program_
->
Proto
()));
new
ProgramDesc
(
*
inference_program_
->
Proto
()));
Singleton
<
Analyzer
>::
Global
().
Run
(
&
argument
);
Singleton
<
Analyzer
>::
Global
().
Run
(
&
argument
);
...
...
paddle/fluid/inference/api/api_tensorrt_subgraph_engine_tester.cc
浏览文件 @
a5c96af3
...
@@ -49,11 +49,10 @@ void CompareTensorRTWithFluid(bool enable_tensorrt) {
...
@@ -49,11 +49,10 @@ void CompareTensorRTWithFluid(bool enable_tensorrt) {
std
::
vector
<
int64_t
>
data
(
20
);
std
::
vector
<
int64_t
>
data
(
20
);
for
(
int
i
=
0
;
i
<
20
;
i
++
)
data
[
i
]
=
i
;
for
(
int
i
=
0
;
i
<
20
;
i
++
)
data
[
i
]
=
i
;
PaddleTensor
tensor
{
PaddleTensor
tensor
;
.
name
=
""
,
tensor
.
shape
=
std
::
vector
<
int
>
({
10
,
1
});
.
shape
=
std
::
vector
<
int
>
({
10
,
1
}),
tensor
.
data
=
PaddleBuf
(
data
.
data
(),
data
.
size
()
*
sizeof
(
int64_t
));
.
data
=
PaddleBuf
(
data
.
data
(),
data
.
size
()
*
sizeof
(
int64_t
)),
tensor
.
dtype
=
PaddleDType
::
INT64
;
.
dtype
=
PaddleDType
::
INT64
};
// For simplicity, we set all the slots with the same data.
// For simplicity, we set all the slots with the same data.
std
::
vector
<
PaddleTensor
>
slots
(
4
,
tensor
);
std
::
vector
<
PaddleTensor
>
slots
(
4
,
tensor
);
...
...
paddle/fluid/inference/api/demo_ci/simple_on_word2vec.cc
浏览文件 @
a5c96af3
...
@@ -47,10 +47,10 @@ void Main(bool use_gpu) {
...
@@ -47,10 +47,10 @@ void Main(bool use_gpu) {
//# 2. Prepare input.
//# 2. Prepare input.
int64_t
data
[
4
]
=
{
1
,
2
,
3
,
4
};
int64_t
data
[
4
]
=
{
1
,
2
,
3
,
4
};
PaddleTensor
tensor
{.
name
=
""
,
PaddleTensor
tensor
;
.
shape
=
std
::
vector
<
int
>
({
4
,
1
}),
tensor
.
shape
=
std
::
vector
<
int
>
({
4
,
1
});
.
data
=
PaddleBuf
(
data
,
sizeof
(
data
)),
tensor
.
data
=
PaddleBuf
(
data
,
sizeof
(
data
));
.
dtype
=
PaddleDType
::
INT64
}
;
tensor
.
dtype
=
PaddleDType
::
INT64
;
// For simplicity, we set all the slots with the same data.
// For simplicity, we set all the slots with the same data.
std
::
vector
<
PaddleTensor
>
slots
(
4
,
tensor
);
std
::
vector
<
PaddleTensor
>
slots
(
4
,
tensor
);
...
@@ -94,10 +94,11 @@ void MainThreads(int num_threads, bool use_gpu) {
...
@@ -94,10 +94,11 @@ void MainThreads(int num_threads, bool use_gpu) {
for
(
int
batch_id
=
0
;
batch_id
<
num_batches
;
++
batch_id
)
{
for
(
int
batch_id
=
0
;
batch_id
<
num_batches
;
++
batch_id
)
{
// 2. Dummy Input Data
// 2. Dummy Input Data
int64_t
data
[
4
]
=
{
1
,
2
,
3
,
4
};
int64_t
data
[
4
]
=
{
1
,
2
,
3
,
4
};
PaddleTensor
tensor
{.
name
=
""
,
PaddleTensor
tensor
;
.
shape
=
std
::
vector
<
int
>
({
4
,
1
}),
tensor
.
shape
=
std
::
vector
<
int
>
({
4
,
1
});
.
data
=
PaddleBuf
(
data
,
sizeof
(
data
)),
tensor
.
data
=
PaddleBuf
(
data
,
sizeof
(
data
));
.
dtype
=
PaddleDType
::
INT64
};
tensor
.
dtype
=
PaddleDType
::
INT64
;
std
::
vector
<
PaddleTensor
>
inputs
(
4
,
tensor
);
std
::
vector
<
PaddleTensor
>
inputs
(
4
,
tensor
);
std
::
vector
<
PaddleTensor
>
outputs
;
std
::
vector
<
PaddleTensor
>
outputs
;
// 3. Run
// 3. Run
...
...
paddle/fluid/inference/api/demo_ci/vis_demo.cc
浏览文件 @
a5c96af3
...
@@ -123,11 +123,11 @@ void Main(bool use_gpu) {
...
@@ -123,11 +123,11 @@ void Main(bool use_gpu) {
file
.
close
();
file
.
close
();
// Inference.
// Inference.
PaddleTensor
input
{
PaddleTensor
input
;
.
name
=
"xx"
,
input
.
shape
=
record
.
shape
;
.
shape
=
record
.
shape
,
input
.
data
=
.
data
=
PaddleBuf
(
record
.
data
.
data
(),
record
.
data
.
size
()
*
sizeof
(
float
)),
PaddleBuf
(
record
.
data
.
data
(),
record
.
data
.
size
()
*
sizeof
(
float
));
.
dtype
=
PaddleDType
::
FLOAT32
}
;
input
.
dtype
=
PaddleDType
::
FLOAT32
;
VLOG
(
3
)
<<
"run executor"
;
VLOG
(
3
)
<<
"run executor"
;
std
::
vector
<
PaddleTensor
>
output
;
std
::
vector
<
PaddleTensor
>
output
;
...
...
paddle/fluid/inference/api/paddle_inference_api.h
浏览文件 @
a5c96af3
...
@@ -67,9 +67,9 @@ struct PaddleTensor {
...
@@ -67,9 +67,9 @@ struct PaddleTensor {
PaddleTensor
()
=
default
;
PaddleTensor
()
=
default
;
std
::
string
name
;
// variable name.
std
::
string
name
;
// variable name.
std
::
vector
<
int
>
shape
;
std
::
vector
<
int
>
shape
;
// TODO(Superjomn) for LoD support, add a vector<vector<int>> field if needed.
PaddleBuf
data
;
// blob of data.
PaddleBuf
data
;
// blob of data.
PaddleDType
dtype
;
PaddleDType
dtype
;
std
::
vector
<
std
::
vector
<
uint64_t
>>
lod
;
// lod data
};
};
enum
class
PaddleEngineKind
{
enum
class
PaddleEngineKind
{
...
...
paddle/fluid/operators/listen_and_serv_op.cc
浏览文件 @
a5c96af3
...
@@ -19,12 +19,17 @@ limitations under the License. */
...
@@ -19,12 +19,17 @@ limitations under the License. */
#include <thread> // NOLINT
#include <thread> // NOLINT
#include <vector>
#include <vector>
#include "gflags/gflags.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler.h"
DEFINE_int32
(
listen_and_serv_profile_period
,
0
,
"the period of listen_and_serv to do profile"
);
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -122,7 +127,18 @@ void ListenAndServOp::RunSyncLoop(
...
@@ -122,7 +127,18 @@ void ListenAndServOp::RunSyncLoop(
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>
(
nullptr
));
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>
(
nullptr
));
rpc_service_
->
ResetBarrierCounter
();
rpc_service_
->
ResetBarrierCounter
();
int32_t
profile_step
=
0
;
while
(
true
)
{
while
(
true
)
{
PADDLE_ENFORCE_LE
(
profile_step
,
FLAGS_listen_and_serv_profile_period
,
"profile_step should not be larger then "
"FLAGS_listen_and_serv_profile_period"
);
if
(
FLAGS_listen_and_serv_profile_period
>
0
)
{
if
(
profile_step
==
0
)
{
auto
pf_state
=
paddle
::
platform
::
ProfilerState
::
kCPU
;
paddle
::
platform
::
EnableProfiler
(
pf_state
);
}
}
// Get from multiple trainers, we don't care about the order in which
// Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient.
// the gradients arrives, just add suffix 0~n and merge the gradient.
rpc_service_
->
SetCond
(
distributed
::
kRequestSend
);
rpc_service_
->
SetCond
(
distributed
::
kRequestSend
);
...
@@ -164,6 +180,15 @@ void ListenAndServOp::RunSyncLoop(
...
@@ -164,6 +180,15 @@ void ListenAndServOp::RunSyncLoop(
// reset received sparse vars to avoid reuse it in the next mini-batch
// reset received sparse vars to avoid reuse it in the next mini-batch
dynamic_cast
<
distributed
::
RequestSendHandler
*>
(
request_send_handler_
.
get
())
dynamic_cast
<
distributed
::
RequestSendHandler
*>
(
request_send_handler_
.
get
())
->
ResetSparseVarRecorder
();
->
ResetSparseVarRecorder
();
if
(
FLAGS_listen_and_serv_profile_period
>
0
)
{
if
(
profile_step
==
FLAGS_listen_and_serv_profile_period
)
{
paddle
::
platform
::
DisableProfiler
(
paddle
::
platform
::
EventSortingKey
::
kTotal
,
"/dev/null"
);
profile_step
=
0
;
}
else
{
profile_step
++
;
}
}
}
// while(true)
}
// while(true)
}
}
...
...
paddle/fluid/operators/math/im2col.cc
浏览文件 @
a5c96af3
...
@@ -14,6 +14,7 @@ limitations under the License. */
...
@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/im2col.h"
#include <vector>
#include <vector>
#include "paddle/fluid/operators/math/im2col_cfo_cpu.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -35,61 +36,18 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
...
@@ -35,61 +36,18 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
PADDLE_ENFORCE
(
im
.
dims
().
size
()
==
3
);
PADDLE_ENFORCE
(
im
.
dims
().
size
()
==
3
);
PADDLE_ENFORCE
(
col
->
dims
().
size
()
==
5
);
PADDLE_ENFORCE
(
col
->
dims
().
size
()
==
5
);
int
im_channels
=
im
.
dims
()[
0
];
int
im_height
=
im
.
dims
()[
1
];
int
im_width
=
im
.
dims
()[
2
];
int
filter_height
=
col
->
dims
()[
1
];
int
filter_width
=
col
->
dims
()[
2
];
int
output_height
=
col
->
dims
()[
3
];
int
output_width
=
col
->
dims
()[
4
];
int
channels_col
=
im_channels
*
filter_height
*
filter_width
;
const
T
*
im_data
=
im
.
data
<
T
>
();
T
*
col_data
=
col
->
data
<
T
>
();
// TODO(TJ): change me to template
// further optimaze:
// 1. padding != 1
// 2. could also support stride_h != 1
if
(
stride
[
0
]
==
1
&&
stride
[
1
]
==
1
&&
dilation
[
0
]
==
1
&&
if
(
stride
[
0
]
==
1
&&
stride
[
1
]
==
1
&&
dilation
[
0
]
==
1
&&
dilation
[
1
]
==
1
&&
padding
[
0
]
==
0
&&
padding
[
1
]
==
0
)
{
dilation
[
1
]
==
1
)
{
int
col_matrix_width
=
output_width
*
output_height
;
if
(
padding
[
0
]
==
0
&&
padding
[
1
]
==
0
)
{
size_t
copy_size
=
sizeof
(
T
)
*
output_width
;
im2col_sh1sw1dh1dw1ph0pw0
<
T
>
(
im
,
col
);
for
(
int
oh
=
0
;
oh
<
output_height
;
++
oh
)
{
return
;
const
T
*
im_data_start
=
im_data
+
oh
*
im_width
;
}
else
if
(
padding
[
0
]
==
1
&&
padding
[
1
]
==
1
)
{
T
*
dst_data
=
col_data
+
oh
*
output_width
;
im2col_sh1sw1dh1dw1ph1pw1
<
T
>
(
im
,
col
);
for
(
int
ic
=
0
;
ic
<
im_channels
;
++
ic
)
{
const
T
*
src_data
=
im_data_start
+
ic
*
im_height
*
im_width
;
for
(
int
kh
=
0
;
kh
<
filter_height
;
++
kh
)
{
for
(
int
kw
=
0
;
kw
<
filter_width
;
++
kw
)
{
std
::
memcpy
(
dst_data
,
src_data
+
kw
,
copy_size
);
dst_data
=
dst_data
+
col_matrix_width
;
}
src_data
=
src_data
+
im_width
;
}
}
}
return
;
return
;
}
}
// TODO(TJ): complete padding >=2
for
(
int
c
=
0
;
c
<
channels_col
;
++
c
)
{
int
w_offset
=
c
%
filter_width
;
int
h_offset
=
(
c
/
filter_width
)
%
filter_height
;
int
c_im
=
c
/
(
filter_width
*
filter_height
);
for
(
int
h
=
0
;
h
<
output_height
;
++
h
)
{
int
im_row_idx
=
h
*
stride
[
0
]
-
padding
[
0
]
+
h_offset
*
dilation
[
0
];
for
(
int
w
=
0
;
w
<
output_width
;
++
w
)
{
int
im_col_idx
=
w
*
stride
[
1
]
-
padding
[
1
]
+
w_offset
*
dilation
[
1
];
int
col_idx
=
(
c
*
output_height
+
h
)
*
output_width
+
w
;
int
im_idx
=
(
im_row_idx
+
c_im
*
im_height
)
*
im_width
+
im_col_idx
;
col_data
[
col_idx
]
=
(
im_row_idx
<
0
||
im_row_idx
>=
im_height
||
im_col_idx
<
0
||
im_col_idx
>=
im_width
)
?
static_cast
<
T
>
(
0
)
:
im_data
[
im_idx
];
}
}
}
}
im2col_common
<
T
>
(
im
,
dilation
,
stride
,
padding
,
col
);
}
}
};
};
...
...
paddle/fluid/operators/math/im2col_cfo_cpu.h
0 → 100644
浏览文件 @
a5c96af3
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/tensor.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
/**
 * Generic im2col: gathers image patches into the column buffer.
 *
 * Handles arbitrary dilation, stride and padding. `im` is indexed as
 * [channel, height, width] and `col` as
 * [channel, filter_height, filter_width, output_height, output_width];
 * both tensors must already hold allocated data of type T.
 *
 * dilation, stride and padding are each read as {height, width} pairs.
 * Positions that fall outside the image are written as zero.
 */
template <typename T>
inline void im2col_common(const framework::Tensor& im,
                          const std::vector<int>& dilation,
                          const std::vector<int>& stride,
                          const std::vector<int>& padding,
                          framework::Tensor* col) {
  const int im_channels = im.dims()[0];
  const int im_height = im.dims()[1];
  const int im_width = im.dims()[2];
  const int filter_height = col->dims()[1];
  const int filter_width = col->dims()[2];
  const int output_height = col->dims()[3];
  const int output_width = col->dims()[4];
  // Number of (channel, kh, kw) planes in the column buffer.
  const int col_planes = im_channels * filter_height * filter_width;

  const T* src = im.data<T>();
  T* dst = col->data<T>();

  for (int plane = 0; plane < col_planes; ++plane) {
    // Decompose the flat plane index into (channel, kh, kw).
    const int kw = plane % filter_width;
    const int kh = (plane / filter_width) % filter_height;
    const int channel = plane / (filter_width * filter_height);

    for (int oh = 0; oh < output_height; ++oh) {
      const int row = oh * stride[0] - padding[0] + kh * dilation[0];
      const bool row_inside = !(row < 0 || row >= im_height);

      for (int ow = 0; ow < output_width; ++ow) {
        const int column = ow * stride[1] - padding[1] + kw * dilation[1];
        const int col_offset = (plane * output_height + oh) * output_width + ow;

        if (row_inside && column >= 0 && column < im_width) {
          const int im_offset = (row + channel * im_height) * im_width + column;
          dst[col_offset] = src[im_offset];
        } else {
          // Source position lies in the padding region.
          dst[col_offset] = static_cast<T>(0);
        }
      }
    }
  }
}
/**
 * im2col specialised for strides == 1, dilations == 1, paddings == 0.
 *
 * With no padding and unit stride, each output row of every
 * (channel, kh, kw) plane is one contiguous run of `output_width`
 * elements from the image, so whole rows are copied with memcpy.
 *
 * `im` is indexed as [channel, height, width] and `col` as
 * [channel, filter_height, filter_width, output_height, output_width].
 * The caller is expected to pass tensors whose shapes are consistent with
 * these parameters; no bounds checks are performed here.
 */
template <typename T>
inline void im2col_sh1sw1dh1dw1ph0pw0(const framework::Tensor& im,
                                      framework::Tensor* col) {
  const int im_channels = im.dims()[0];
  const int im_height = im.dims()[1];
  const int im_width = im.dims()[2];
  const int filter_height = col->dims()[1];
  const int filter_width = col->dims()[2];
  const int output_height = col->dims()[3];
  const int output_width = col->dims()[4];

  const T* image = im.data<T>();
  T* columns = col->data<T>();

  const int plane_size = output_width * output_height;  // oh*ow
  const int channel_size = im_height * im_width;        // ih*iw
  const size_t row_bytes = sizeof(T) * output_width;

  for (int oh = 0; oh < output_height; ++oh) {
    // Row `oh` of the first plane; later planes are reached by stepping
    // plane_size elements at a time.
    T* out_row = columns + oh * output_width;

    for (int ic = 0; ic < im_channels; ++ic) {
      // Image row that lines up with filter row 0 for this output row.
      const T* in_row = image + ic * channel_size + oh * im_width;

      for (int kh = 0; kh < filter_height; ++kh) {
        for (int kw = 0; kw < filter_width; ++kw) {
          std::memcpy(out_row, in_row + kw, row_bytes);
          out_row += plane_size;
        }
        in_row += im_width;
      }
    }
  }
}
/**
 * im2col specialised for strides == 1, dilations == 1, paddings == 1.
 *
 * The work is done in three passes over the column buffer:
 *   1. zero the rows that come from the top / bottom padding,
 *   2. zero the single elements that come from the left / right padding,
 *   3. memcpy the remaining (in-image) part of every row.
 * filter_width == 1 takes a dedicated path because every row then has the
 * same shape: one pad element on each side around a full image row.
 *
 * `im` is indexed as [channel, height, width] and `col` as
 * [channel, filter_height, filter_width, output_height, output_width].
 * No shape validation is performed here.
 */
template <typename T>
inline void im2col_sh1sw1dh1dw1ph1pw1(const framework::Tensor& im,
                                      framework::Tensor* col) {
  const int im_channels = im.dims()[0];
  const int im_height = im.dims()[1];
  const int im_width = im.dims()[2];
  const int filter_height = col->dims()[1];
  const int filter_width = col->dims()[2];
  const int output_height = col->dims()[3];
  const int output_width = col->dims()[4];

  // Padding on the top (plh), bottom (prh), left (plw) and right (prw).
  constexpr int plh = 1;
  constexpr int prh = 1;
  constexpr int plw = 1;
  constexpr int prw = 1;

  const T* im_data = im.data<T>();
  T* col_data = col->data<T>();
  const int im_size = im_height * im_width;
  const int col_matrix_width = output_width * output_height;
  const int col_block_fh = filter_width * col_matrix_width;  // fw*oh*ow
  const int col_block_ic = filter_height * col_block_fh;     // fh*fw*oh*ow

  // True when output row `oh` combined with filter row `kh` reads from the
  // top or bottom padding; such rows were zeroed in pass 1 and are skipped.
  auto reads_height_padding = [&](int oh, int kh) {
    return (oh < plh && kh < plh) ||
           (oh > (output_height - prh - 1) && kh > (filter_height - prh - 1));
  };

  // Pass 1: fill height padding.
  {
    const size_t row_bytes = sizeof(T) * output_width;
    // First output row of filter row 0 ...
    T* top_base = col_data;
    // ... and last output row of the last filter row.
    T* bottom_base = col_data + (filter_height - 1) * col_block_fh +
                     col_matrix_width - output_width;
    for (int ic = 0; ic < im_channels; ++ic) {
      T* top = top_base;
      T* bottom = bottom_base;
      for (int kw = 0; kw < filter_width; ++kw) {
        std::memset(top, 0, row_bytes);
        std::memset(bottom, 0, row_bytes);
        top += col_matrix_width;
        bottom += col_matrix_width;
      }
      top_base += col_block_ic;
      bottom_base += col_block_ic;
    }
  }

  const T pad = static_cast<T>(0);

  if (filter_width == 1) {
    // Pass 2: fill width padding. Every row starts and ends with a pad.
    T* channel_base = col_data;
    for (int ic = 0; ic < im_channels; ++ic) {
      T* kh_base = channel_base;
      for (int kh = 0; kh < filter_height; ++kh) {
        T* row = kh_base;
        for (int oh = 0; oh < output_height; ++oh) {
          row[0] = pad;
          row[output_width - 1] = pad;
          row += output_width;
        }
        kh_base += col_block_fh;
      }
      channel_base += col_block_ic;
    }

    // Pass 3: fill core.
    const size_t core_bytes = sizeof(T) * (output_width - plw - prw);
    for (int oh = 0; oh < output_height; ++oh) {
      const T* im_row_start =
          im_data + (oh - plh > 0 ? oh - plh : 0) * im_width;
      T* dst = col_data + oh * output_width;
      for (int ic = 0; ic < im_channels; ++ic) {
        const T* src = im_row_start + ic * im_size;
        for (int kh = 0; kh < filter_height; ++kh) {
          if (!reads_height_padding(oh, kh)) {
            std::memcpy(dst + plw, src, core_bytes);
            // The image row only advances when a real row was consumed.
            src += im_width;
          }
          dst += col_matrix_width;
        }
      }
    }
  } else {
    // filter_width != 1
    // Pass 2: fill width padding. Only kw == 0 touches the left pad (at
    // ow == 0) and only the last kw touches the right pad (at the last ow).
    T* channel_base = col_data;
    for (int ic = 0; ic < im_channels; ++ic) {
      T* kh_base = channel_base;
      for (int kh = 0; kh < filter_height; ++kh) {
        T* left = kh_base;
        T* right = kh_base + (filter_width - prw) * col_matrix_width +
                   output_width - 1;
        // TODO(TJ): from plh, saving repeated assignment
        for (int oh = 0; oh < output_height; ++oh) {
          *left = pad;
          *right = pad;
          left += output_width;
          right += output_width;
        }
        kh_base += col_block_fh;
      }
      channel_base += col_block_ic;
    }

    // Pass 3: fill core.
    // TODO(TJ): use array like: size_t copy_size[kw]={sizeof(T) *
    // (output_width-1)}
    // length of copy_size is equal kw.
    for (int oh = 0; oh < output_height; ++oh) {
      const T* im_row_start =
          im_data + (oh - plh > 0 ? oh - plh : 0) * im_width;
      T* dst = col_data + oh * output_width;
      for (int ic = 0; ic < im_channels; ++ic) {
        const T* src = im_row_start + ic * im_size;
        for (int kh = 0; kh < filter_height; ++kh) {
          if (reads_height_padding(oh, kh)) {
            // Skip all kw planes of this filter row at once.
            dst += filter_width * col_matrix_width;
            continue;
          }
          // TODO(TJ): reuse plw-kw outside this for
          // try to unify
          // Left-most kw: the row is shifted right past the left pad.
          for (int kw = 0; kw < plw; ++kw) {
            const int shift = plw - kw;
            std::memcpy(dst + shift, src,
                        sizeof(T) * (output_width - shift));
            dst += col_matrix_width;
          }
          // Middle kw: the full row comes from the image.
          for (int kw = plw; kw < filter_width - prw; ++kw) {
            std::memcpy(dst, src + (kw - plw), sizeof(T) * output_width);
            dst += col_matrix_width;
          }
          // Right-most kw: the row is shortened by the right pad.
          int trimmed = 1;
          for (int kw = filter_width - prw; kw < filter_width;
               ++kw, ++trimmed) {
            std::memcpy(dst, src + (kw - plw),
                        sizeof(T) * (output_width - trimmed));
            dst += col_matrix_width;
          }
          src += im_width;
        }
      }
    }
  }
}
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/im2col_test.cc
浏览文件 @
a5c96af3
...
@@ -14,7 +14,9 @@ limitations under the License. */
...
@@ -14,7 +14,9 @@ limitations under the License. */
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/im2col.h"
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include <sys/time.h>
#include <vector>
#include <vector>
#include "paddle/fluid/operators/math/im2col_cfo_cpu.h"
template
<
typename
DeviceContext
,
typename
Place
>
template
<
typename
DeviceContext
,
typename
Place
>
void
testIm2col
()
{
void
testIm2col
()
{
...
@@ -160,82 +162,111 @@ void testIm2col() {
...
@@ -160,82 +162,111 @@ void testIm2col() {
delete
context
;
delete
context
;
}
}
TEST
(
math
,
im2col
)
{
testIm2col
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
CPUPlace
>
();
#ifdef PADDLE_WITH_CUDA
testIm2col
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
platform
::
CUDAPlace
>
();
#endif
}
#define PREPARE_IM2COL_CPU \
paddle::platform::CPUPlace place; \
paddle::platform::CPUDeviceContext context(place); \
paddle::framework::Tensor input; \
paddle::framework::Tensor out; \
paddle::framework::Tensor ref; \
std::vector<int> padding({ph, pw}); \
std::vector<int> stride({1, 1}); \
std::vector<int> dilation({1, 1}); \
float* input_ptr = input.mutable_data<float>({ic, ih, iw}, place); \
for (int i = 0; i < input.numel(); ++i) { \
input_ptr[i] = static_cast<float>(i + 1); \
} \
int output_height = (ih - fh + padding[0] * 2) / stride[0] + 1; \
int output_width = (iw - fw + padding[1] * 2) / stride[1] + 1; \
out.mutable_data<float>({ic, fh, fw, output_height, output_width}, place); \
ref.mutable_data<float>({ic, fh, fw, output_height, output_width}, place); \
paddle::operators::math::Im2ColFunctor< \
paddle::operators::math::ColFormat::kCFO, \
paddle::platform::CPUDeviceContext, float> \
im2col
void
testIm2colCPU
(
int
ic
,
int
ih
,
int
iw
,
int
fh
,
int
fw
,
int
ph
,
int
pw
)
{
void
testIm2colCPU
(
int
ic
,
int
ih
,
int
iw
,
int
fh
,
int
fw
,
int
ph
,
int
pw
)
{
paddle
::
framework
::
Tensor
input
;
PREPARE_IM2COL_CPU
;
paddle
::
framework
::
Tensor
output
;
paddle
::
framework
::
Tensor
ref_output
;
im2col
(
context
,
input
,
dilation
,
stride
,
padding
,
&
out
);
std
::
vector
<
int
>
padding
({
ph
,
pw
});
paddle
::
operators
::
math
::
im2col_common
<
float
>
(
input
,
dilation
,
stride
,
std
::
vector
<
int
>
stride
({
1
,
1
});
// stride_y, stride_x
padding
,
&
ref
);
std
::
vector
<
int
>
dilation
({
1
,
1
});
// dilation_y, dilation_x
int
output_height
=
(
ih
-
fh
+
padding
[
0
]
*
2
)
/
stride
[
0
]
+
1
;
float
*
ref_data
=
ref
.
data
<
float
>
();
int
output_width
=
(
iw
-
fw
+
padding
[
1
]
*
2
)
/
stride
[
1
]
+
1
;
float
*
out_data
=
out
.
data
<
float
>
();
float
*
input_ptr
=
for
(
int
i
=
0
;
i
<
out
.
numel
();
++
i
)
{
input
.
mutable_data
<
float
>
({
ic
,
ih
,
iw
},
paddle
::
platform
::
CPUPlace
());
EXPECT_EQ
(
out_data
[
i
],
ref_data
[
i
]);
for
(
int
i
=
0
;
i
<
input
.
numel
();
++
i
)
{
input_ptr
[
i
]
=
static_cast
<
float
>
(
i
+
1
);
}
paddle
::
platform
::
CPUPlace
place
;
paddle
::
platform
::
CPUDeviceContext
context
(
place
);
output
.
mutable_data
<
float
>
({
ic
,
fh
,
fw
,
output_height
,
output_width
},
place
);
ref_output
.
mutable_data
<
float
>
({
ic
,
fh
,
fw
,
output_height
,
output_width
},
place
);
paddle
::
operators
::
math
::
Im2ColFunctor
<
paddle
::
operators
::
math
::
ColFormat
::
kCFO
,
paddle
::
platform
::
CPUDeviceContext
,
float
>
im2col
;
im2col
(
context
,
input
,
dilation
,
stride
,
padding
,
&
output
);
auto
ref_im2col
=
[
&
](
const
paddle
::
framework
::
Tensor
&
im
,
const
std
::
vector
<
int
>&
dilation
,
const
std
::
vector
<
int
>&
stride
,
const
std
::
vector
<
int
>&
padding
,
paddle
::
framework
::
Tensor
*
col
)
{
int
im_channels
=
im
.
dims
()[
0
];
int
im_height
=
im
.
dims
()[
1
];
int
im_width
=
im
.
dims
()[
2
];
int
filter_height
=
col
->
dims
()[
1
];
int
filter_width
=
col
->
dims
()[
2
];
int
output_height
=
col
->
dims
()[
3
];
int
output_width
=
col
->
dims
()[
4
];
int
channels_col
=
im_channels
*
filter_height
*
filter_width
;
const
float
*
im_data
=
im
.
data
<
float
>
();
float
*
col_data
=
col
->
data
<
float
>
();
for
(
int
c
=
0
;
c
<
channels_col
;
++
c
)
{
int
w_offset
=
c
%
filter_width
;
int
h_offset
=
(
c
/
filter_width
)
%
filter_height
;
int
c_im
=
c
/
(
filter_width
*
filter_height
);
for
(
int
h
=
0
;
h
<
output_height
;
++
h
)
{
int
im_row_idx
=
h
*
stride
[
0
]
-
padding
[
0
]
+
h_offset
*
dilation
[
0
];
for
(
int
w
=
0
;
w
<
output_width
;
++
w
)
{
int
im_col_idx
=
w
*
stride
[
1
]
-
padding
[
1
]
+
w_offset
*
dilation
[
1
];
int
col_idx
=
(
c
*
output_height
+
h
)
*
output_width
+
w
;
int
im_idx
=
(
im_row_idx
+
c_im
*
im_height
)
*
im_width
+
im_col_idx
;
col_data
[
col_idx
]
=
(
im_row_idx
<
0
||
im_row_idx
>=
im_height
||
im_col_idx
<
0
||
im_col_idx
>=
im_width
)
?
0.
f
:
im_data
[
im_idx
];
}
}
}
void
benchIm2col
(
int
ic
,
int
ih
,
int
iw
,
int
fh
,
int
fw
,
int
ph
,
int
pw
)
{
PREPARE_IM2COL_CPU
;
constexpr
int
repeat
=
100
;
auto
GetCurrentMs
=
[]()
->
double
{
struct
timeval
time
;
gettimeofday
(
&
time
,
NULL
);
return
1e+3
*
time
.
tv_sec
+
1e-3
*
time
.
tv_usec
;
};
auto
t1
=
GetCurrentMs
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
im2col
(
context
,
input
,
dilation
,
stride
,
padding
,
&
out
);
}
}
auto
t2
=
GetCurrentMs
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
paddle
::
operators
::
math
::
im2col_common
<
float
>
(
input
,
dilation
,
stride
,
padding
,
&
ref
);
}
}
}
;
auto
t3
=
GetCurrentMs
()
;
ref_im2col
(
input
,
dilation
,
stride
,
padding
,
&
ref_output
);
LOG
(
INFO
)
<<
"before: "
<<
(
t3
-
t2
)
/
repeat
<<
",after: "
<<
(
t2
-
t1
)
/
repeat
<<
",boost: "
<<
((
t3
-
t2
)
/
(
t2
-
t1
)
-
1
)
*
100
<<
"%"
;
}
float
*
out_cfo_ptr
=
output
.
data
<
float
>
();
TEST
(
math
,
im2col_cputest
)
{
float
*
out_ref_ptr
=
ref_output
.
data
<
float
>
();
// padding_h == padding_w
for
(
int
i
=
0
;
i
<
output
.
numel
();
++
i
)
{
for
(
int
p
=
0
;
p
<
4
;
++
p
)
{
EXPECT_EQ
(
out_cfo_ptr
[
i
],
out_ref_ptr
[
i
]);
// width == height
testIm2colCPU
(
/*ic*/
2
,
/*ih*/
5
,
/*iw*/
5
,
/*fh*/
4
,
/*fw*/
4
,
/*ph*/
p
,
/*pw*/
p
);
testIm2colCPU
(
/*ic*/
2
,
/*ih*/
4
,
/*iw*/
4
,
/*fh*/
3
,
/*fw*/
3
,
/*ph*/
p
,
/*pw*/
p
);
testIm2colCPU
(
/*ic*/
2
,
/*ih*/
4
,
/*iw*/
4
,
/*fh*/
2
,
/*fw*/
2
,
/*ph*/
p
,
/*pw*/
p
);
// height != width
testIm2colCPU
(
/*ic*/
2
,
/*ih*/
5
,
/*iw*/
4
,
/*fh*/
2
,
/*fw*/
3
,
/*ph*/
p
,
/*pw*/
p
);
testIm2colCPU
(
/*ic*/
2
,
/*ih*/
5
,
/*iw*/
4
,
/*fh*/
1
,
/*fw*/
3
,
/*ph*/
p
,
/*pw*/
p
);
testIm2colCPU
(
/*ic*/
2
,
/*ih*/
4
,
/*iw*/
5
,
/*fh*/
3
,
/*fw*/
1
,
/*ph*/
p
,
/*pw*/
p
);
// filter == 1
testIm2colCPU
(
/*ic*/
3
,
/*ih*/
4
,
/*iw*/
4
,
/*fh*/
1
,
/*fw*/
1
,
/*ph*/
p
,
/*pw*/
p
);
testIm2colCPU
(
/*ic*/
3
,
/*ih*/
3
,
/*iw*/
4
,
/*fh*/
1
,
/*fw*/
1
,
/*ph*/
p
,
/*pw*/
p
);
}
}
}
TEST
(
math
,
im2col
)
{
// padding_h != padding_w
testIm2col
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
CPUPlace
>
();
testIm2colCPU
(
/*ic*/
2
,
/*ih*/
4
,
/*iw*/
4
,
/*fh*/
2
,
/*fw*/
3
,
/*ph*/
1
,
testIm2colCPU
(
/*ic*/
3
,
/*ih*/
5
,
/*iw*/
5
,
/*fh*/
3
,
/*fw*/
2
,
/*ph*/
0
,
/*pw*/
2
);
/*pw*/
0
);
testIm2colCPU
(
/*ic*/
2
,
/*ih*/
5
,
/*iw*/
4
,
/*fh*/
3
,
/*fw*/
3
,
/*ph*/
1
,
// benchmark
/*pw*/
1
);
for
(
int
p
:
{
0
,
1
})
{
#ifdef PADDLE_WITH_CUDA
for
(
int
k
:
{
1
,
3
,
5
})
{
testIm2col
<
paddle
::
platform
::
CUDADeviceContext
,
LOG
(
INFO
)
<<
"padding == "
<<
p
<<
", filter == "
<<
k
;
paddle
::
platform
::
CUDAPlace
>
();
benchIm2col
(
/*ic*/
3
,
/*ih*/
224
,
/*iw*/
224
,
/*fh*/
k
,
/*fw*/
k
,
#endif
/*ph*/
p
,
/*pw*/
p
);
}
}
}
}
paddle/fluid/operators/reshape_op.cc
浏览文件 @
a5c96af3
...
@@ -127,12 +127,6 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -127,12 +127,6 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput
(
"Out"
,
"(Tensor). The output tensor of reshape operator."
);
AddOutput
(
"Out"
,
"(Tensor). The output tensor of reshape operator."
);
AddAttr
<
std
::
vector
<
int
>>
(
AddAttr
<
std
::
vector
<
int
>>
(
"shape"
,
"(std::vector<int>) Target shape of reshape operator."
);
"shape"
,
"(std::vector<int>) Target shape of reshape operator."
);
AddAttr
<
bool
>
(
"inplace"
,
"(default: false) Change the source tensor's shape without "
"memory copy. When Attr(inplace) is set true, the output "
"tensor shares memory with Input(X), otherwise, a new output "
"tensor is created, and its data are copied from Input(x)."
)
.
SetDefault
(
false
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
Reshape Operator.
Reshape Operator.
...
@@ -233,16 +227,9 @@ class ReshapeKernel {
...
@@ -233,16 +227,9 @@ class ReshapeKernel {
"sequence_reshape op."
);
"sequence_reshape op."
);
}
}
bool
inplace
=
ctx
.
Attr
<
bool
>
(
"inplace"
);
out
->
Resize
(
out_dims
);
if
(
!
inplace
)
{
out
->
mutable_data
(
ctx
.
GetPlace
(),
in
->
type
());
out
->
mutable_data
(
ctx
.
GetPlace
(),
in
->
type
());
framework
::
TensorCopySync
(
*
in
,
ctx
.
GetPlace
(),
out
);
framework
::
TensorCopySync
(
*
in
,
ctx
.
GetPlace
(),
out
);
out
->
Resize
(
out_dims
);
out
->
Resize
(
out_dims
);
}
else
{
out
->
ShareDataWith
(
*
in
);
out
->
Resize
(
out_dims
);
}
}
}
};
};
...
@@ -251,19 +238,11 @@ class ReshapeGradKernel {
...
@@ -251,19 +238,11 @@ class ReshapeGradKernel {
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
)
const
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
)
const
{
auto
*
d_out
=
ctx
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
d_out
=
ctx
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
d_x
=
ctx
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
d_x
=
ctx
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
in_dims
=
d_x
->
dims
();
d_x
->
mutable_data
(
ctx
.
GetPlace
(),
d_out
->
type
());
d_x
->
mutable_data
(
ctx
.
GetPlace
(),
d_out
->
type
());
bool
inplace
=
ctx
.
Attr
<
bool
>
(
"inplace"
);
framework
::
TensorCopySync
(
*
d_out
,
ctx
.
GetPlace
(),
d_x
);
auto
in_dims
=
d_x
->
dims
();
if
(
!
inplace
)
{
framework
::
TensorCopy
(
*
d_out
,
ctx
.
GetPlace
(),
ctx
.
device_context
(),
d_x
);
ctx
.
device_context
().
Wait
();
d_x
->
Resize
(
in_dims
);
d_x
->
Resize
(
in_dims
);
}
else
{
d_x
->
ShareDataWith
(
*
d_out
);
d_x
->
Resize
(
in_dims
);
}
}
}
};
};
...
...
paddle/fluid/operators/split_ids_op.h
浏览文件 @
a5c96af3
...
@@ -14,6 +14,7 @@ limitations under the License. */
...
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#pragma once
#include <unordered_map>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
...
@@ -67,10 +68,15 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
...
@@ -67,10 +68,15 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
const
auto
&
ids_rows
=
ids_selected_rows
->
rows
();
const
auto
&
ids_rows
=
ids_selected_rows
->
rows
();
auto
outs
=
ctx
.
MultiOutput
<
framework
::
SelectedRows
>
(
"Out"
);
auto
outs
=
ctx
.
MultiOutput
<
framework
::
SelectedRows
>
(
"Out"
);
const
size_t
shard_num
=
outs
.
size
();
const
size_t
shard_num
=
outs
.
size
();
for
(
auto
&
out
:
outs
)
{
out
->
mutable_rows
()
->
clear
();
}
// get rows for outputs
// get rows for outputs
for
(
auto
&
id
:
ids_rows
)
{
std
::
unordered_map
<
int64_t
,
size_t
>
id_to_index
;
size_t
shard_id
=
static_cast
<
size_t
>
(
id
)
%
shard_num
;
for
(
size_t
i
=
0
;
i
<
ids_rows
.
size
();
++
i
)
{
outs
[
shard_id
]
->
mutable_rows
()
->
push_back
(
id
);
id_to_index
[
ids_rows
[
i
]]
=
i
;
size_t
shard_id
=
static_cast
<
size_t
>
(
ids_rows
[
i
])
%
shard_num
;
outs
[
shard_id
]
->
mutable_rows
()
->
push_back
(
ids_rows
[
i
]);
}
}
int64_t
row_width
=
ids_dims
[
1
];
int64_t
row_width
=
ids_dims
[
1
];
...
@@ -80,7 +86,8 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
...
@@ -80,7 +86,8 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
{
static_cast
<
int64_t
>
(
out
->
rows
().
size
()),
row_width
});
{
static_cast
<
int64_t
>
(
out
->
rows
().
size
()),
row_width
});
T
*
output
=
out
->
mutable_value
()
->
mutable_data
<
T
>
(
ddim
,
place
);
T
*
output
=
out
->
mutable_value
()
->
mutable_data
<
T
>
(
ddim
,
place
);
for
(
int64_t
i
=
0
;
i
<
ddim
[
0
];
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
ddim
[
0
];
++
i
)
{
memcpy
(
output
+
i
*
row_width
,
ids
+
out
->
rows
()[
i
]
*
row_width
,
memcpy
(
output
+
i
*
row_width
,
ids
+
id_to_index
[
out
->
rows
()[
i
]]
*
row_width
,
row_width
*
sizeof
(
T
));
row_width
*
sizeof
(
T
));
}
}
}
}
...
...
paddle/fluid/platform/cuda_helper_test.cu
浏览文件 @
a5c96af3
...
@@ -13,7 +13,6 @@
...
@@ -13,7 +13,6 @@
// limitations under the License.
// limitations under the License.
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include <bitset>
#include <iostream>
#include <iostream>
#include <random>
#include <random>
...
@@ -25,13 +24,13 @@
...
@@ -25,13 +24,13 @@
using
paddle
::
platform
::
PADDLE_CUDA_NUM_THREADS
;
using
paddle
::
platform
::
PADDLE_CUDA_NUM_THREADS
;
using
paddle
::
platform
::
float16
;
using
paddle
::
platform
::
float16
;
#define CUDA_ATOMIC_KERNEL(op, T) \
template
<
typename
T
>
__global__ void op##Kernel(const T* data_a, T* data_b, size_t num) { \
__global__
void
AddKernel
(
const
T
*
data_a
,
T
*
data_b
,
size_t
num
)
{
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; \
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
num
;
i += blockDim.x * gridDim.x) { \
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
paddle::platform::CudaAtomic##op(&data_b[i], data_a[i]); \
paddle
::
platform
::
CudaAtomicAdd
(
&
data_b
[
i
],
data_a
[
i
]);
} \
}
}
}
template
<
typename
T
>
template
<
typename
T
>
struct
AddFunctor
{
struct
AddFunctor
{
...
@@ -39,80 +38,116 @@ struct AddFunctor {
...
@@ -39,80 +38,116 @@ struct AddFunctor {
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
SubFunctor
{
void
TestCase
(
size_t
num
)
{
T
operator
()(
const
T
&
a
,
const
T
&
b
)
{
return
a
-
b
;
}
T
*
in1
,
*
in2
,
*
out
;
};
T
*
d_in1
,
*
d_in2
;
size_t
size
=
sizeof
(
T
)
*
num
;
// NOTE(dzhwinter): the float16 add has small underflow/overflow
cudaMalloc
(
reinterpret_cast
<
void
**>
(
&
d_in1
),
size
);
// so we use EXPECT_NEAR to check the result.
cudaMalloc
(
reinterpret_cast
<
void
**>
(
&
d_in2
),
size
);
#define ARITHMETIC_KERNEL_LAUNCH(op, T) \
in1
=
reinterpret_cast
<
T
*>
(
malloc
(
size
));
void Test##T##op(size_t num) { \
in2
=
reinterpret_cast
<
T
*>
(
malloc
(
size
));
T *in1, *in2, *out; \
out
=
reinterpret_cast
<
T
*>
(
malloc
(
size
));
T *d_in1, *d_in2; \
std
::
minstd_rand
engine
;
size_t size = sizeof(T) * num; \
std
::
uniform_real_distribution
<
double
>
dist
(
0.0
,
1.0
);
cudaMalloc(reinterpret_cast<void**>(&d_in1), size); \
for
(
size_t
i
=
0
;
i
<
num
;
++
i
)
{
cudaMalloc(reinterpret_cast<void**>(&d_in2), size); \
in1
[
i
]
=
static_cast
<
T
>
(
dist
(
engine
));
in1 = reinterpret_cast<T*>(malloc(size)); \
in2
[
i
]
=
static_cast
<
T
>
(
dist
(
engine
));
in2 = reinterpret_cast<T*>(malloc(size)); \
out = reinterpret_cast<T*>(malloc(size)); \
std::minstd_rand engine; \
std::uniform_real_distribution<double> dist(0.0, 1.0); \
for (size_t i = 0; i < num; ++i) { \
in1[i] = static_cast<T>(dist(engine)); \
in2[i] = static_cast<T>(dist(engine)); \
} \
cudaMemcpy(d_in1, in1, size, cudaMemcpyHostToDevice); \
cudaMemcpy(d_in2, in2, size, cudaMemcpyHostToDevice); \
op##Kernel<<<1, PADDLE_CUDA_NUM_THREADS>>>(d_in1, d_in2, num); \
cudaDeviceSynchronize(); \
cudaMemcpy(out, d_in2, size, cudaMemcpyDeviceToHost); \
cudaDeviceSynchronize(); \
for (size_t i = 0; i < num; ++i) { \
EXPECT_NEAR(static_cast<float>(out[i]), \
static_cast<float>(op##Functor<T>()(in1[i], in2[i])), \
0.001); \
} \
free(in1); \
free(in2); \
free(out); \
cudaFree(d_in1); \
cudaFree(d_in2); \
}
}
CUDA_ATOMIC_KERNEL
(
Add
,
float
);
cudaMemcpy
(
d_in1
,
in1
,
size
,
cudaMemcpyHostToDevice
);
CUDA_ATOMIC_KERNEL
(
Add
,
double
);
cudaMemcpy
(
d_in2
,
in2
,
size
,
cudaMemcpyHostToDevice
);
CUDA_ATOMIC_KERNEL
(
Add
,
float16
);
AddKernel
<
T
><<<
1
,
PADDLE_CUDA_NUM_THREADS
>>>
(
d_in1
,
d_in2
,
num
);
cudaDeviceSynchronize
();
ARITHMETIC_KERNEL_LAUNCH
(
Add
,
float
);
cudaMemcpy
(
out
,
d_in2
,
size
,
cudaMemcpyDeviceToHost
);
ARITHMETIC_KERNEL_LAUNCH
(
Add
,
double
);
cudaDeviceSynchronize
();
ARITHMETIC_KERNEL_LAUNCH
(
Add
,
float16
);
for
(
size_t
i
=
0
;
i
<
num
;
++
i
)
{
// NOTE(dzhwinter): the float16 add has small underflow/overflow
namespace
paddle
{
// so we use EXPECT_NEAR to check the result.
namespace
platform
{
EXPECT_NEAR
(
static_cast
<
float
>
(
out
[
i
]),
USE_CUDA_ATOMIC
(
Sub
,
int
);
static_cast
<
float
>
(
AddFunctor
<
T
>
()(
in1
[
i
],
in2
[
i
])),
0.001
);
};
}
};
free
(
in1
);
CUDA_ATOMIC_KERNEL
(
Sub
,
int
);
free
(
in2
);
ARITHMETIC_KERNEL_LAUNCH
(
Sub
,
int
);
free
(
out
);
cudaFree
(
d_in1
);
cudaFree
(
d_in2
);
}
// cuda primitives
// cuda primitives
TEST
(
CudaAtomic
,
Add
)
{
TEST
(
CudaAtomic
,
Add
)
{
TestfloatAdd
(
static_cast
<
size_t
>
(
10
));
TestCase
<
float
>
(
static_cast
<
size_t
>
(
10
));
TestfloatAdd
(
static_cast
<
size_t
>
(
1024
*
1024
));
TestCase
<
float
>
(
static_cast
<
size_t
>
(
1024
*
1024
));
TestdoubleAdd
(
static_cast
<
size_t
>
(
10
));
TestdoubleAdd
(
static_cast
<
size_t
>
(
1024
*
1024
));
}
TEST
(
CudaAtomic
,
Sub
)
{
TestCase
<
double
>
(
static_cast
<
size_t
>
(
10
));
TestintSub
(
static_cast
<
size_t
>
(
10
));
TestCase
<
double
>
(
static_cast
<
size_t
>
(
1024
*
1024
));
TestintSub
(
static_cast
<
size_t
>
(
1024
*
1024
));
}
}
TEST
(
CudaAtomic
,
float16
)
{
TEST
(
CudaAtomic
,
float16
)
{
using
paddle
::
platform
::
float16
;
TestCase
<
float16
>
(
static_cast
<
size_t
>
(
1
));
Testfloat16Add
(
static_cast
<
size_t
>
(
1
));
TestCase
<
float16
>
(
static_cast
<
size_t
>
(
2
));
Testfloat16Add
(
static_cast
<
size_t
>
(
2
));
TestCase
<
float16
>
(
static_cast
<
size_t
>
(
3
));
Testfloat16Add
(
static_cast
<
size_t
>
(
3
));
TestCase
<
float16
>
(
static_cast
<
size_t
>
(
10
));
TestCase
<
float16
>
(
static_cast
<
size_t
>
(
1024
*
1024
));
}
// Runs the float16 atomic-add kernel on host inputs that start `shift_bit`
// BYTES (despite the name) past the start of their malloc'ed buffers, and
// checks the result against AddFunctor on the host.
// `num` is the payload size in bytes and must be even; num / 2 float16
// values are processed.
// NOTE(review): only the host pointers are offset; the device buffers are
// used from the address cudaMalloc returned — confirm this is the intent.
void TestUnalign(size_t num, const int shift_bit) {
  PADDLE_ENFORCE(num % 2 == 0, "must be a multiple of 2");

  // Raw buffers hold the payload plus room for the byte offset.
  const size_t size = sizeof(uint8_t) * (num + shift_bit);
  const size_t count = num / 2;
  const size_t array_size = sizeof(float16) * count;

  float16* d_in1;
  float16* d_in2;
  cudaMalloc(reinterpret_cast<void**>(&d_in1), size);
  cudaMalloc(reinterpret_cast<void**>(&d_in2), size);
  float16* in1 = reinterpret_cast<float16*>(malloc(size));
  float16* in2 = reinterpret_cast<float16*>(malloc(size));
  float16* out = reinterpret_cast<float16*>(malloc(size));

  // Offset the host pointers by shift_bit bytes to mimic an unaligned
  // address.
  float16* r_in1 =
      reinterpret_cast<float16*>(reinterpret_cast<uint8_t*>(in1) + shift_bit);
  float16* r_in2 =
      reinterpret_cast<float16*>(reinterpret_cast<uint8_t*>(in2) + shift_bit);

  std::minstd_rand engine;
  std::uniform_real_distribution<double> dist(0.0, 1.0);
  for (size_t idx = 0; idx < count; ++idx) {
    r_in1[idx] = static_cast<float16>(dist(engine));
    r_in2[idx] = static_cast<float16>(dist(engine));
  }

  cudaMemcpy(d_in1, r_in1, array_size, cudaMemcpyHostToDevice);
  cudaMemcpy(d_in2, r_in2, array_size, cudaMemcpyHostToDevice);
  AddKernel<float16><<<1, PADDLE_CUDA_NUM_THREADS>>>(d_in1, d_in2, count);
  cudaDeviceSynchronize();
  cudaMemcpy(out, d_in2, array_size, cudaMemcpyDeviceToHost);
  cudaDeviceSynchronize();

  for (size_t idx = 0; idx < count; ++idx) {
    // NOTE(dzhwinter): the float16 add has small underflow/overflow
    // so we use EXPECT_NEAR to check the result.
    const float expected =
        static_cast<float>(AddFunctor<float16>()(r_in1[idx], r_in2[idx]));
    EXPECT_NEAR(static_cast<float>(out[idx]), expected, 0.001);
  }

  free(in1);
  free(in2);
  free(out);
  cudaFree(d_in1);
  cudaFree(d_in2);
}
TEST
(
CudaAtomic
,
float16Unalign
)
{
// same with float16 testcase
TestUnalign
(
static_cast
<
size_t
>
(
2
),
/*shift_bit*/
2
);
TestUnalign
(
static_cast
<
size_t
>
(
1024
),
/*shift_bit*/
2
);
TestUnalign
(
static_cast
<
size_t
>
(
1024
*
1024
),
/*shift_bit*/
2
);
// shift the address.
TestUnalign
(
static_cast
<
size_t
>
(
2
),
/*shift_bit*/
1
);
TestUnalign
(
static_cast
<
size_t
>
(
1024
),
/*shift_bit*/
1
);
TestUnalign
(
static_cast
<
size_t
>
(
1024
*
1024
),
/*shift_bit*/
1
);
Testfloat16Add
(
static_cast
<
size_t
>
(
10
));
TestUnalign
(
static_cast
<
size_t
>
(
2
),
/*shift_bit*/
3
);
Testfloat16Add
(
static_cast
<
size_t
>
(
1024
*
1024
));
TestUnalign
(
static_cast
<
size_t
>
(
1024
),
/*shift_bit*/
3
);
TestUnalign
(
static_cast
<
size_t
>
(
1024
*
1024
),
/*shift_bit*/
3
);
}
}
paddle/fluid/platform/cuda_primitives.h
浏览文件 @
a5c96af3
...
@@ -79,41 +79,41 @@ CUDA_ATOMIC_WRAPPER(Add, double) {
...
@@ -79,41 +79,41 @@ CUDA_ATOMIC_WRAPPER(Add, double) {
// convert the value into float and do the add arithmetic.
// convert the value into float and do the add arithmetic.
// then store the result into a uint32.
// then store the result into a uint32.
inline
__device__
uint32_t
add_to_low_half
(
uint32_t
val
,
float
x
)
{
inline
static
__device__
uint32_t
add_to_low_half
(
uint32_t
val
,
float
x
)
{
float16
low_half
;
float16
low_half
;
// the float16 in lower 16bits
// the float16 in lower 16bits
low_half
.
x
=
static_cast
<
uint16_t
>
(
val
&
0x
ffff
u
);
low_half
.
x
=
static_cast
<
uint16_t
>
(
val
&
0x
FFFF
u
);
low_half
=
static_cast
<
float16
>
(
static_cast
<
float
>
(
low_half
)
+
x
);
low_half
=
static_cast
<
float16
>
(
static_cast
<
float
>
(
low_half
)
+
x
);
return
(
val
&
0x
ffff
0000u
)
|
low_half
.
x
;
return
(
val
&
0x
FFFF
0000u
)
|
low_half
.
x
;
}
}
inline
__device__
uint32_t
add_to_high_half
(
uint32_t
val
,
float
x
)
{
inline
static
__device__
uint32_t
add_to_high_half
(
uint32_t
val
,
float
x
)
{
float16
high_half
;
float16
high_half
;
// the float16 in higher 16bits
// the float16 in higher 16bits
high_half
.
x
=
static_cast
<
uint16_t
>
(
val
>>
16
);
high_half
.
x
=
static_cast
<
uint16_t
>
(
val
>>
16
);
high_half
=
static_cast
<
float16
>
(
static_cast
<
float
>
(
high_half
)
+
x
);
high_half
=
static_cast
<
float16
>
(
static_cast
<
float
>
(
high_half
)
+
x
);
return
(
val
&
0x
ffff
u
)
|
(
static_cast
<
uint32_t
>
(
high_half
.
x
)
<<
16
);
return
(
val
&
0x
FFFF
u
)
|
(
static_cast
<
uint32_t
>
(
high_half
.
x
)
<<
16
);
}
}
CUDA_ATOMIC_WRAPPER
(
Add
,
float16
)
{
CUDA_ATOMIC_WRAPPER
(
Add
,
float16
)
{
// concrete packed float16 value may exsits in lower or higher 16bits
// concrete packed float16 value may exsits in lower or higher 16bits
// of the 32bits address.
// of the 32bits address.
uint32_t
*
address_as_ui
=
uint32_t
*
address_as_ui
=
reinterpret_cast
<
uint32_t
*>
(
reinterpret_cast
<
uint32_t
*>
(
reinterpret_cast
<
char
*>
(
address
)
-
reinterpret_cast
<
char
*>
(
address
)
-
(
reinterpret_cast
<
size_t
>
(
address
)
&
2
));
(
reinterpret_cast
<
uintptr_t
>
(
address
)
&
0x0
2
));
float
val_f
=
static_cast
<
float
>
(
val
);
float
val_f
=
static_cast
<
float
>
(
val
);
uint32_t
old
=
*
address_as_ui
;
uint32_t
old
=
*
address_as_ui
;
uint32_t
sum
;
uint32_t
sum
;
uint32_t
newval
;
uint32_t
newval
;
uint32_t
assumed
;
uint32_t
assumed
;
if
(((
size_t
)
address
&
2
)
==
0
)
{
if
(((
uintptr_t
)
address
&
0x0
2
)
==
0
)
{
// the float16 value stay at lower 16 bits of the address.
// the float16 value stay at lower 16 bits of the address.
do
{
do
{
assumed
=
old
;
assumed
=
old
;
old
=
atomicCAS
(
address_as_ui
,
assumed
,
add_to_low_half
(
assumed
,
val_f
));
old
=
atomicCAS
(
address_as_ui
,
assumed
,
add_to_low_half
(
assumed
,
val_f
));
}
while
(
old
!=
assumed
);
}
while
(
old
!=
assumed
);
float16
ret
;
float16
ret
;
ret
.
x
=
old
&
0x
ffff
u
;
ret
.
x
=
old
&
0x
FFFF
u
;
return
ret
;
return
ret
;
}
else
{
}
else
{
// the float16 value stay at higher 16 bits of the address.
// the float16 value stay at higher 16 bits of the address.
...
...
paddle/scripts/paddle_build.sh
浏览文件 @
a5c96af3
...
@@ -534,7 +534,7 @@ EOF
...
@@ -534,7 +534,7 @@ EOF
make
-j
`
nproc
`
inference_lib_dist
make
-j
`
nproc
`
inference_lib_dist
cd
${
PADDLE_ROOT
}
/build
cd
${
PADDLE_ROOT
}
/build
cp
-r
fluid_install_dir fluid
cp
-r
fluid_install_dir fluid
tar
-cf
fluid.tgz fluid
tar
-c
z
f
fluid.tgz fluid
fi
fi
}
}
...
...
python/paddle/fluid/__init__.py
浏览文件 @
a5c96af3
...
@@ -127,6 +127,7 @@ def __bootstrap__():
...
@@ -127,6 +127,7 @@ def __bootstrap__():
]
]
if
core
.
is_compiled_with_dist
():
if
core
.
is_compiled_with_dist
():
read_env_flags
.
append
(
'rpc_deadline'
)
read_env_flags
.
append
(
'rpc_deadline'
)
read_env_flags
.
append
(
'listen_and_serv_profile_period'
)
if
core
.
is_compiled_with_cuda
():
if
core
.
is_compiled_with_cuda
():
read_env_flags
+=
[
read_env_flags
+=
[
...
...
python/paddle/fluid/layers/control_flow.py
浏览文件 @
a5c96af3
...
@@ -21,6 +21,7 @@ from ..layer_helper import LayerHelper, unique_name
...
@@ -21,6 +21,7 @@ from ..layer_helper import LayerHelper, unique_name
from
..initializer
import
force_init_on_cpu
from
..initializer
import
force_init_on_cpu
from
ops
import
logical_and
,
logical_not
,
logical_or
from
ops
import
logical_and
,
logical_not
,
logical_or
import
numpy
import
numpy
import
warnings
__all__
=
[
__all__
=
[
'While'
,
'While'
,
...
@@ -280,6 +281,9 @@ class ParallelDo(object):
...
@@ -280,6 +281,9 @@ class ParallelDo(object):
"""
"""
def
__init__
(
self
,
places
,
use_nccl
=
False
,
name
=
None
):
def
__init__
(
self
,
places
,
use_nccl
=
False
,
name
=
None
):
warnings
.
warn
(
"API ParallelDo is deprecated since 0.15.0. Please use ParallelExecutor instead."
,
Warning
)
self
.
helper
=
LayerHelper
(
"parallel_do"
,
name
=
name
)
self
.
helper
=
LayerHelper
(
"parallel_do"
,
name
=
name
)
self
.
inputs
=
[]
self
.
inputs
=
[]
self
.
places
=
places
self
.
places
=
places
...
@@ -338,7 +342,7 @@ class ParallelDo(object):
...
@@ -338,7 +342,7 @@ class ParallelDo(object):
return
[
parent_block
.
var
(
name
)
for
name
in
params
]
return
[
parent_block
.
var
(
name
)
for
name
in
params
]
def
complete_op
(
self
):
def
_
complete_op
(
self
):
main_program
=
self
.
helper
.
main_program
main_program
=
self
.
helper
.
main_program
current_block
=
main_program
.
current_block
()
current_block
=
main_program
.
current_block
()
parent_block
=
self
.
parent_block
()
parent_block
=
self
.
parent_block
()
...
@@ -394,7 +398,7 @@ class BlockGuardWithCompletion(BlockGuard):
...
@@ -394,7 +398,7 @@ class BlockGuardWithCompletion(BlockGuard):
if
exc_type
is
not
None
:
if
exc_type
is
not
None
:
return
False
return
False
self
.
rnn
.
status
=
StaticRNN
.
AFTER_RNN_BLOCK
self
.
rnn
.
status
=
StaticRNN
.
AFTER_RNN_BLOCK
self
.
rnn
.
complete_op
()
self
.
rnn
.
_
complete_op
()
return
super
(
BlockGuardWithCompletion
,
self
).
__exit__
(
exc_type
,
exc_val
,
return
super
(
BlockGuardWithCompletion
,
self
).
__exit__
(
exc_type
,
exc_val
,
exc_tb
)
exc_tb
)
...
@@ -470,7 +474,7 @@ class StaticRNN(object):
...
@@ -470,7 +474,7 @@ class StaticRNN(object):
if
shape
is
None
or
batch_ref
is
None
:
if
shape
is
None
or
batch_ref
is
None
:
raise
ValueError
(
raise
ValueError
(
"if init is None, memory at least need shape and batch_ref"
)
"if init is None, memory at least need shape and batch_ref"
)
parent_block
=
self
.
parent_block
()
parent_block
=
self
.
_
parent_block
()
var_name
=
unique_name
.
generate
(
"@"
.
join
(
var_name
=
unique_name
.
generate
(
"@"
.
join
(
[
self
.
helper
.
name
,
"memory_boot"
]))
[
self
.
helper
.
name
,
"memory_boot"
]))
boot_var
=
parent_block
.
create_var
(
boot_var
=
parent_block
.
create_var
(
...
@@ -527,7 +531,7 @@ class StaticRNN(object):
...
@@ -527,7 +531,7 @@ class StaticRNN(object):
outputs
=
{
'Out'
:
tmp_o
},
outputs
=
{
'Out'
:
tmp_o
},
attrs
=
{
'dtype'
:
o
.
dtype
})
attrs
=
{
'dtype'
:
o
.
dtype
})
out_var
=
self
.
parent_block
().
create_var
(
out_var
=
self
.
_
parent_block
().
create_var
(
name
=
tmp_o
.
name
,
name
=
tmp_o
.
name
,
shape
=
[
self
.
seq_len
]
+
list
(
tmp_o
.
shape
),
shape
=
[
self
.
seq_len
]
+
list
(
tmp_o
.
shape
),
dtype
=
tmp_o
.
dtype
)
dtype
=
tmp_o
.
dtype
)
...
@@ -543,7 +547,7 @@ class StaticRNN(object):
...
@@ -543,7 +547,7 @@ class StaticRNN(object):
raise
TypeError
(
"update memory should take variables"
)
raise
TypeError
(
"update memory should take variables"
)
self
.
memories
[
mem
.
name
].
mem
=
var
self
.
memories
[
mem
.
name
].
mem
=
var
def
parent_block
(
self
):
def
_
parent_block
(
self
):
prog
=
self
.
helper
.
main_program
prog
=
self
.
helper
.
main_program
parent_idx
=
prog
.
current_block
().
parent_idx
parent_idx
=
prog
.
current_block
().
parent_idx
assert
parent_idx
>=
0
assert
parent_idx
>=
0
...
@@ -560,10 +564,10 @@ class StaticRNN(object):
...
@@ -560,10 +564,10 @@ class StaticRNN(object):
else
:
else
:
return
self
.
outputs
return
self
.
outputs
def
complete_op
(
self
):
def
_
complete_op
(
self
):
main_program
=
self
.
helper
.
main_program
main_program
=
self
.
helper
.
main_program
rnn_block
=
main_program
.
current_block
()
rnn_block
=
main_program
.
current_block
()
parent_block
=
self
.
parent_block
()
parent_block
=
self
.
_
parent_block
()
local_inputs
=
set
()
local_inputs
=
set
()
...
@@ -643,7 +647,7 @@ class WhileGuard(BlockGuard):
...
@@ -643,7 +647,7 @@ class WhileGuard(BlockGuard):
if
exc_type
is
not
None
:
if
exc_type
is
not
None
:
return
False
return
False
self
.
while_op
.
status
=
While
.
AFTER_WHILE_BLOCK
self
.
while_op
.
status
=
While
.
AFTER_WHILE_BLOCK
self
.
while_op
.
complete
()
self
.
while_op
.
_
complete
()
return
super
(
WhileGuard
,
self
).
__exit__
(
exc_type
,
exc_val
,
exc_tb
)
return
super
(
WhileGuard
,
self
).
__exit__
(
exc_type
,
exc_val
,
exc_tb
)
...
@@ -690,7 +694,7 @@ class While(object):
...
@@ -690,7 +694,7 @@ class While(object):
def
block
(
self
):
def
block
(
self
):
return
WhileGuard
(
self
)
return
WhileGuard
(
self
)
def
complete
(
self
):
def
_
complete
(
self
):
main_program
=
self
.
helper
.
main_program
main_program
=
self
.
helper
.
main_program
while_block
=
main_program
.
current_block
()
while_block
=
main_program
.
current_block
()
parent_block
=
main_program
.
block
(
main_program
.
current_block
()
parent_block
=
main_program
.
block
(
main_program
.
current_block
()
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
a5c96af3
...
@@ -4473,15 +4473,14 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None):
...
@@ -4473,15 +4473,14 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None):
"except one unknown dimension."
)
"except one unknown dimension."
)
helper
=
LayerHelper
(
"reshape"
,
**
locals
())
helper
=
LayerHelper
(
"reshape"
,
**
locals
())
reshaped
=
helper
.
create_tmp_variable
(
dtype
=
x
.
dtype
)
out
=
helper
.
create_tmp_variable
(
dtype
=
x
.
dtype
)
helper
.
append_op
(
helper
.
append_op
(
type
=
"reshape"
,
type
=
"reshape"
,
inputs
=
inputs
,
inputs
=
inputs
,
attrs
=
{
"shape"
:
shape
,
attrs
=
{
"shape"
:
shape
},
"inplace"
:
inplace
},
outputs
=
{
"Out"
:
out
})
outputs
=
{
"Out"
:
reshaped
})
return
helper
.
append_activation
(
reshaped
)
return
helper
.
append_activation
(
out
)
def
lod_reset
(
x
,
y
=
None
,
target_lod
=
None
):
def
lod_reset
(
x
,
y
=
None
,
target_lod
=
None
):
...
...
python/paddle/fluid/tests/unittests/test_memory_optimization_transpiler.py
浏览文件 @
a5c96af3
...
@@ -43,5 +43,29 @@ class TestControlFlowGraph(unittest.TestCase):
...
@@ -43,5 +43,29 @@ class TestControlFlowGraph(unittest.TestCase):
print
(
str
(
result_program
))
print
(
str
(
result_program
))
class
TestMemoryTranspiler2
(
unittest
.
TestCase
):
def
setUp
(
self
):
program
=
Program
()
with
program_guard
(
program
,
startup_program
=
Program
()):
x
=
layers
.
data
(
name
=
'x'
,
shape
=
[
13
],
dtype
=
'float32'
)
fc
=
layers
.
fc
(
input
=
x
,
size
=
10
,
act
=
None
)
reshape
=
layers
.
reshape
(
x
=
fc
,
shape
=
[
-
1
,
2
,
5
])
fc
=
layers
.
reshape
(
x
=
reshape
,
shape
=
[
-
1
,
5
,
2
])
y_predict
=
layers
.
fc
(
input
=
fc
,
size
=
1
,
act
=
None
)
y
=
layers
.
data
(
name
=
'y'
,
shape
=
[
1
],
dtype
=
'float32'
)
cost
=
layers
.
square_error_cost
(
input
=
y_predict
,
label
=
y
)
avg_cost
=
layers
.
mean
(
cost
)
opt
=
optimizer
.
SGD
(
learning_rate
=
0.001
)
opt
.
minimize
(
avg_cost
)
self
.
program
=
program
def
test_inplace_ops
(
self
):
print
(
"before optimization"
)
print
(
str
(
self
.
program
))
result_program
=
memory_optimize
(
self
.
program
)
print
(
"after optimization"
)
print
(
str
(
result_program
))
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_reshape_op.py
浏览文件 @
a5c96af3
...
@@ -25,7 +25,7 @@ class TestReshapeOp(OpTest):
...
@@ -25,7 +25,7 @@ class TestReshapeOp(OpTest):
self
.
op_type
=
"reshape"
self
.
op_type
=
"reshape"
self
.
inputs
=
{
"X"
:
np
.
random
.
random
(
ori_shape
).
astype
(
"float32"
)}
self
.
inputs
=
{
"X"
:
np
.
random
.
random
(
ori_shape
).
astype
(
"float32"
)}
self
.
attrs
=
{
"shape"
:
new_shape
,
"inplace"
:
False
}
self
.
attrs
=
{
"shape"
:
new_shape
}
self
.
outputs
=
{
"Out"
:
self
.
inputs
[
"X"
].
reshape
(
new_shape
)}
self
.
outputs
=
{
"Out"
:
self
.
inputs
[
"X"
].
reshape
(
new_shape
)}
def
test_check_output
(
self
):
def
test_check_output
(
self
):
...
@@ -42,7 +42,7 @@ class TestReshapeOpDimInfer1(OpTest):
...
@@ -42,7 +42,7 @@ class TestReshapeOpDimInfer1(OpTest):
self
.
op_type
=
"reshape"
self
.
op_type
=
"reshape"
self
.
inputs
=
{
"X"
:
np
.
random
.
random
(
ori_shape
).
astype
(
"float32"
)}
self
.
inputs
=
{
"X"
:
np
.
random
.
random
(
ori_shape
).
astype
(
"float32"
)}
self
.
attrs
=
{
"shape"
:
new_shape
,
"inplace"
:
False
}
self
.
attrs
=
{
"shape"
:
new_shape
}
self
.
outputs
=
{
"Out"
:
self
.
inputs
[
"X"
].
reshape
(
self
.
attrs
[
"shape"
])}
self
.
outputs
=
{
"Out"
:
self
.
inputs
[
"X"
].
reshape
(
self
.
attrs
[
"shape"
])}
def
test_check_output
(
self
):
def
test_check_output
(
self
):
...
@@ -60,7 +60,7 @@ class TestReshapeOpDimInfer2(OpTest):
...
@@ -60,7 +60,7 @@ class TestReshapeOpDimInfer2(OpTest):
self
.
op_type
=
"reshape"
self
.
op_type
=
"reshape"
self
.
inputs
=
{
"X"
:
np
.
random
.
random
(
ori_shape
).
astype
(
"float32"
)}
self
.
inputs
=
{
"X"
:
np
.
random
.
random
(
ori_shape
).
astype
(
"float32"
)}
self
.
attrs
=
{
"shape"
:
new_shape
,
"inplace"
:
False
}
self
.
attrs
=
{
"shape"
:
new_shape
}
self
.
outputs
=
{
"Out"
:
self
.
inputs
[
"X"
].
reshape
(
infered_shape
)}
self
.
outputs
=
{
"Out"
:
self
.
inputs
[
"X"
].
reshape
(
infered_shape
)}
def
test_check_output
(
self
):
def
test_check_output
(
self
):
...
...
python/paddle/fluid/tests/unittests/test_split_ids_op.py
浏览文件 @
a5c96af3
...
@@ -15,6 +15,8 @@
...
@@ -15,6 +15,8 @@
import
unittest
import
unittest
import
numpy
as
np
import
numpy
as
np
from
op_test
import
OpTest
from
op_test
import
OpTest
import
paddle.fluid.core
as
core
from
paddle.fluid.op
import
Operator
class
TestSplitIdsOp
(
OpTest
):
class
TestSplitIdsOp
(
OpTest
):
...
@@ -31,5 +33,55 @@ class TestSplitIdsOp(OpTest):
...
@@ -31,5 +33,55 @@ class TestSplitIdsOp(OpTest):
self
.
check_output
()
self
.
check_output
()
class
TestSpliteIds
(
unittest
.
TestCase
):
def
get_places
(
self
):
places
=
[
core
.
CPUPlace
()]
return
places
def
test_check_output
(
self
):
for
place
in
self
.
get_places
():
self
.
check_with_place
(
place
)
def
check_with_place
(
self
,
place
):
scope
=
core
.
Scope
()
rows
=
[
0
,
5
,
7
,
4
,
9
]
height
=
20
row_numel
=
2
# initialize input variable X
x
=
scope
.
var
(
'X'
).
get_selected_rows
()
x
.
set_rows
(
rows
)
x
.
set_height
(
height
)
np_array
=
np
.
ones
((
len
(
rows
),
row_numel
)).
astype
(
"float32"
)
for
i
in
range
(
len
(
rows
)):
for
j
in
range
(
row_numel
):
np_array
[
i
,
j
]
=
rows
[
i
]
+
j
x_tensor
=
x
.
get_tensor
()
x_tensor
.
set
(
np_array
,
place
)
outs_name
=
[
"out%d"
%
i
for
i
in
xrange
(
3
)]
outs
=
[
scope
.
var
(
var_name
).
get_selected_rows
()
for
var_name
in
outs_name
]
# expected output selected rows
expected_out_rows
=
[[
0
,
9
],
[
7
,
4
],
[
5
]]
op
=
Operator
(
"split_ids"
,
Ids
=
"X"
,
Out
=
outs_name
)
for
_
in
range
(
3
):
op
.
run
(
scope
,
place
)
for
i
in
range
(
len
(
outs
)):
expected_rows
=
expected_out_rows
[
i
]
self
.
assertEqual
(
outs
[
i
].
rows
(),
expected_rows
)
for
j
in
range
(
len
(
expected_rows
)):
row
=
expected_rows
[
j
]
self
.
assertAlmostEqual
(
float
(
row
),
np
.
array
(
outs
[
i
].
get_tensor
())[
j
,
0
])
self
.
assertAlmostEqual
(
float
(
row
+
1
),
np
.
array
(
outs
[
i
].
get_tensor
())[
j
,
1
])
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
a5c96af3
...
@@ -495,6 +495,7 @@ class DistributeTranspiler(object):
...
@@ -495,6 +495,7 @@ class DistributeTranspiler(object):
pserver_index
=
self
.
pserver_endpoints
.
index
(
endpoint
)
pserver_index
=
self
.
pserver_endpoints
.
index
(
endpoint
)
table_opt_block
=
self
.
_create_table_optimize_block
(
table_opt_block
=
self
.
_create_table_optimize_block
(
pserver_index
,
pserver_program
,
pre_block_idx
,
grad_to_block_id
)
pserver_index
,
pserver_program
,
pre_block_idx
,
grad_to_block_id
)
optimize_blocks
.
append
(
table_opt_block
)
prefetch_var_name_to_block_id
=
self
.
_create_prefetch_block
(
prefetch_var_name_to_block_id
=
self
.
_create_prefetch_block
(
pserver_index
,
pserver_program
,
table_opt_block
)
pserver_index
,
pserver_program
,
table_opt_block
)
checkpoint_block_id
=
self
.
_create_checkpoint_save_block
(
checkpoint_block_id
=
self
.
_create_checkpoint_save_block
(
...
...
tools/manylinux1/Dockerfile.x64
浏览文件 @
a5c96af3
...
@@ -13,7 +13,7 @@ ENV PATH /opt/rh/devtoolset-2/root/usr/bin:$PATH
...
@@ -13,7 +13,7 @@ ENV PATH /opt/rh/devtoolset-2/root/usr/bin:$PATH
ENV LD_LIBRARY_PATH /opt/rh/devtoolset-2/root/usr/lib64:/opt/rh/devtoolset-2/root/usr/lib:/usr/local/lib64:/usr/local/lib:${LD_LIBRARY_PATH}
ENV LD_LIBRARY_PATH /opt/rh/devtoolset-2/root/usr/lib64:/opt/rh/devtoolset-2/root/usr/lib:/usr/local/lib64:/usr/local/lib:${LD_LIBRARY_PATH}
ENV PKG_CONFIG_PATH=/usr/local/lib/pkgconfig
ENV PKG_CONFIG_PATH=/usr/local/lib/pkgconfig
RUN yum install -y sqlite-devel zlib-devel openssl-devel pcre-devel vim tk-devel tkinter libtool xz
RUN yum install -y sqlite-devel zlib-devel openssl-devel pcre-devel vim tk-devel tkinter libtool xz
graphviz
COPY build_scripts /build_scripts
COPY build_scripts /build_scripts
RUN bash build_scripts/build.sh && \
RUN bash build_scripts/build.sh && \
bash build_scripts/install_nccl2.sh && rm -r build_scripts
bash build_scripts/install_nccl2.sh && rm -r build_scripts
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录