Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
aa1085dd
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
aa1085dd
编写于
7月 26, 2018
作者:
X
Xin Pan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
all passes
add doc
上级
e4d7d7ae
变更
10
显示空白变更内容
内联
并排
Showing
10 changed file
with
87 addition
and
154 deletion
+87
-154
doc/fluid/design/ir/draft.md
doc/fluid/design/ir/draft.md
+38
-0
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+1
-1
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+0
-3
paddle/fluid/framework/details/multi_devices_graph_builder.cc
...le/fluid/framework/details/multi_devices_graph_builder.cc
+5
-5
paddle/fluid/framework/details/ssa_graph_builder_factory.cc
paddle/fluid/framework/details/ssa_graph_builder_factory.cc
+0
-53
paddle/fluid/framework/details/ssa_graph_builder_factory.h
paddle/fluid/framework/details/ssa_graph_builder_factory.h
+0
-71
paddle/fluid/framework/details/ssa_graph_checker.h
paddle/fluid/framework/details/ssa_graph_checker.h
+2
-3
paddle/fluid/framework/details/ssa_graph_printer.h
paddle/fluid/framework/details/ssa_graph_printer.h
+3
-5
paddle/fluid/framework/ir/pass.h
paddle/fluid/framework/ir/pass.h
+5
-0
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+33
-13
未找到文件。
doc/fluid/design/ir/draft.md
浏览文件 @
aa1085dd
...
...
@@ -71,6 +71,44 @@ is a `Graph` and its output is also a `Graph`. For example,
a
`Pass`
can simply print out the
`Graph`
. A
`Pass`
can also fuse some
`Graph`
's
`Node`
s.
```
cpp
class
Pass
{
public:
virtual
std
::
unique_ptr
<
Graph
>
Apply
(
std
::
unique_ptr
<
Graph
>
graph
)
const
=
0
;
// Get a reference to the attributed previously set.
template
<
typename
AttrType
>
AttrType
&
Get
(
const
std
::
string
&
attr_name
)
const
;
// Set a pointer to the attribute. Pass takes ownership of the attribute.
template
<
typename
AttrType
>
void
Set
(
const
std
::
string
&
attr_name
,
AttrType
*
attr
)
;
// Set a pointer to the attribute. Pass doesn't take ownership. Caller
// should delete the attribute.
template
<
typename
AttrType
>
void
SetNotOwned
(
const
std
::
string
&
attr_name
,
AttrType
*
attr
);
};
// In my_pass.cc
class
MyPass
:
public
Pass
{
public:
std
::
unique_ptr
<
Graph
>
Apply
(
std
::
unique_ptr
<
Graph
>
graph
)
const
override
{
// do something.
return
graph
;
}
}
REGISTER_PASS
(
my_pass
,
MyPass
);
// To use the pass.
auto
my_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"my_pass"
);
graph
=
my_pass
->
Apply
(
std
::
move
(
graph
));
// Note: to force link my_pass.cc, in the code:
USE_PASS
(
my_pass
);
```
#### Optimize
`Optimize`
contains a series of
`Pass`
with defined order.
...
...
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
aa1085dd
...
...
@@ -99,7 +99,7 @@ else()
endif
()
cc_library
(
parallel_executor SRCS parallel_executor.cc DEPS
ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph graph_viz_pass
)
cc_library
(
parallel_executor SRCS parallel_executor.cc DEPS
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph graph_viz_pass multi_devices_graph_builder ssa_graph_printer ssa_graph_checker
)
cc_library
(
prune SRCS prune.cc DEPS framework_proto
)
cc_test
(
prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context
)
...
...
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
aa1085dd
...
...
@@ -31,9 +31,6 @@ cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base s
cc_library
(
multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle
)
cc_library
(
ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer ssa_graph_checker
)
cc_library
(
ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto
)
cc_library
(
threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
simple_threadpool device_context
)
...
...
paddle/fluid/framework/details/multi_devices_graph_builder.cc
浏览文件 @
aa1085dd
...
...
@@ -35,15 +35,15 @@ namespace framework {
namespace
details
{
void
MultiDevSSAGraphBuilder
::
Init
()
const
{
loss_var_name_
=
Get
<
std
::
string
>
(
"loss_var_name"
);
places_
=
Get
<
std
::
vector
<
platform
::
Place
>>
(
"places"
);
local_scopes_
=
Get
<
std
::
vector
<
Scope
*>>
(
"local_scopes"
);
strategy_
=
Get
<
BuildStrategy
>
(
"strategy"
);
loss_var_name_
=
Get
<
const
std
::
string
>
(
"loss_var_name"
);
places_
=
Get
<
const
std
::
vector
<
platform
::
Place
>>
(
"places"
);
local_scopes_
=
Get
<
const
std
::
vector
<
Scope
*>>
(
"local_scopes"
);
strategy_
=
Get
<
const
BuildStrategy
>
(
"strategy"
);
#ifdef PADDLE_WITH_CUDA
nccl_ctxs_
=
&
Get
<
platform
::
NCCLContextMap
>
(
"nccl_ctxs"
);
#endif
for
(
auto
&
p
:
Get
<
std
::
unordered_set
<
std
::
string
>>
(
"params"
))
{
for
(
auto
&
p
:
Get
<
const
std
::
unordered_set
<
std
::
string
>>
(
"params"
))
{
grad_names_
.
insert
(
GradVarName
(
p
));
}
balance_vars_
.
resize
(
places_
.
size
(),
0
);
...
...
paddle/fluid/framework/details/ssa_graph_builder_factory.cc
已删除
100644 → 0
浏览文件 @
e4d7d7ae
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/ssa_graph_builder_factory.h"
#include <fstream>
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/ssa_graph_checker.h"
#include "paddle/fluid/framework/details/ssa_graph_printer.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
std
::
unique_ptr
<
ir
::
Pass
>
ParallelExecutorPassManager
::
Create
()
{
std
::
unique_ptr
<
ir
::
Pass
>
res
(
new
MultiDevSSAGraphBuilder
);
res
->
SetNotOwned
<
std
::
vector
<
platform
::
Place
>>
(
"places"
,
&
places_
);
res
->
SetNotOwned
<
std
::
string
>
(
"loss_var_name"
,
&
loss_var_name_
);
res
->
SetNotOwned
<
std
::
unordered_set
<
std
::
string
>>
(
"params"
,
&
param_names_
);
res
->
SetNotOwned
<
std
::
vector
<
Scope
*>>
(
"local_scopes"
,
&
local_scopes_
);
res
->
SetNotOwned
<
BuildStrategy
>
(
"strategy"
,
&
strategy_
);
#ifdef PADDLE_WITH_CUDA
res
->
SetNotOwned
<
platform
::
NCCLContextMap
>
(
"nccl_ctxs"
,
nccl_ctxs_
);
#endif
if
(
!
strategy_
.
debug_graphviz_path_
.
empty
())
{
ir
::
Pass
*
previous_pass
=
res
.
release
();
res
.
reset
(
new
SSAGraghBuilderWithPrinter
);
res
->
Set
<
ir
::
Pass
>
(
"previous_pass"
,
previous_pass
);
res
->
SetNotOwned
<
std
::
string
>
(
"debug_graphviz_path"
,
&
strategy_
.
debug_graphviz_path_
);
res
->
Set
<
GraphvizSSAGraphPrinter
>
(
"graph_printer"
,
new
GraphvizSSAGraphPrinter
);
}
ir
::
Pass
*
previous_pass
=
res
.
release
();
res
.
reset
(
new
SSAGraghBuilderWithChecker
);
res
->
Set
<
ir
::
Pass
>
(
"previous_pass"
,
previous_pass
);
return
res
;
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/ssa_graph_builder_factory.h
已删除
100644 → 0
浏览文件 @
e4d7d7ae
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace
paddle
{
namespace
framework
{
class
Scope
;
namespace
details
{
class
ParallelExecutorPassManager
{
public:
ParallelExecutorPassManager
(
const
std
::
vector
<
platform
::
Place
>&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>&
param_names
,
const
std
::
vector
<
Scope
*>&
local_scopes
,
const
BuildStrategy
&
strategy
)
:
places_
(
places
),
loss_var_name_
(
loss_var_name
),
param_names_
(
param_names
),
local_scopes_
(
local_scopes
),
strategy_
(
strategy
)
{
#ifdef PADDLE_WITH_CUDA
nccl_ctxs_
=
nullptr
;
#endif
}
#ifdef PADDLE_WITH_CUDA
void
SetNCCLContextMap
(
platform
::
NCCLContextMap
*
nccl_ctxs
)
{
nccl_ctxs_
=
nccl_ctxs
;
}
#endif
std
::
unique_ptr
<
ir
::
Pass
>
Create
();
private:
std
::
vector
<
platform
::
Place
>
places_
;
std
::
string
loss_var_name_
;
std
::
unordered_set
<
std
::
string
>
param_names_
;
std
::
vector
<
Scope
*>
local_scopes_
;
BuildStrategy
strategy_
;
#ifdef PADDLE_WITH_CUDA
platform
::
NCCLContextMap
*
nccl_ctxs_
;
#endif
};
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/ssa_graph_checker.h
浏览文件 @
aa1085dd
...
...
@@ -26,9 +26,8 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
public:
std
::
unique_ptr
<
ir
::
Graph
>
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
{
auto
new_graph
=
Get
<
ir
::
Pass
>
(
"previous_pass"
).
Apply
(
std
::
move
(
graph
));
PADDLE_ENFORCE
(
IsValidGraph
(
new_graph
.
get
()));
return
new_graph
;
PADDLE_ENFORCE
(
IsValidGraph
(
graph
.
get
()));
return
graph
;
}
bool
IsValidGraph
(
const
ir
::
Graph
*
graph
)
const
;
...
...
paddle/fluid/framework/details/ssa_graph_printer.h
浏览文件 @
aa1085dd
...
...
@@ -39,13 +39,11 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
public:
std
::
unique_ptr
<
ir
::
Graph
>
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
{
auto
new_graph
=
Get
<
ir
::
Pass
>
(
"previous_pass"
).
Apply
(
std
::
move
(
graph
));
std
::
unique_ptr
<
std
::
ostream
>
fout
(
new
std
::
ofstream
(
Get
<
std
::
string
>
(
"debug_graphviz_path"
)));
new
std
::
ofstream
(
Get
<
const
std
::
string
>
(
"debug_graphviz_path"
)));
PADDLE_ENFORCE
(
fout
->
good
());
Get
<
GraphvizSSAGraphPrinter
>
(
"graph_printer"
).
Print
(
*
new_
graph
,
*
fout
);
return
new_
graph
;
Get
<
GraphvizSSAGraphPrinter
>
(
"graph_printer"
).
Print
(
*
graph
,
*
fout
);
return
graph
;
}
};
...
...
paddle/fluid/framework/ir/pass.h
浏览文件 @
aa1085dd
...
...
@@ -42,6 +42,7 @@ class Pass {
virtual
std
::
unique_ptr
<
Graph
>
Apply
(
std
::
unique_ptr
<
Graph
>
graph
)
const
=
0
;
// Get a reference to the attributed previously set.
template
<
typename
AttrType
>
AttrType
&
Get
(
const
std
::
string
&
attr_name
)
const
{
PADDLE_ENFORCE
(
attrs_
.
find
(
attr_name
)
!=
attrs_
.
end
(),
...
...
@@ -49,6 +50,7 @@ class Pass {
return
*
boost
::
any_cast
<
AttrType
*>
(
attrs_
.
at
(
attr_name
));
}
// Set a pointer to the attribute. Pass takes ownership of the attribute.
template
<
typename
AttrType
>
void
Set
(
const
std
::
string
&
attr_name
,
AttrType
*
attr
)
{
PADDLE_ENFORCE
(
attrs_
.
count
(
attr_name
)
==
0
);
...
...
@@ -59,6 +61,8 @@ class Pass {
};
}
// Set a pointer to the attribute. Pass doesn't take ownership. Caller
// should delete the attribute.
template
<
typename
AttrType
>
void
SetNotOwned
(
const
std
::
string
&
attr_name
,
AttrType
*
attr
)
{
PADDLE_ENFORCE
(
attrs_
.
count
(
attr_name
)
==
0
);
...
...
@@ -127,6 +131,7 @@ struct PassRegistrar : public Registrar {
__test_global_namespace_##uniq_name##__>::value, \
msg)
// Register a new pass that can be applied on the IR.
#define REGISTER_PASS(pass_type, pass_class) \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
__reg_pass__##pass_type, \
...
...
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
aa1085dd
...
...
@@ -26,7 +26,8 @@ limitations under the License. */
#endif
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/ssa_graph_builder_factory.h"
#include "paddle/fluid/framework/details/ssa_graph_checker.h"
#include "paddle/fluid/framework/details/ssa_graph_printer.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/platform/profiler.h"
...
...
@@ -43,16 +44,6 @@ std::unique_ptr<ir::Graph> ApplyParallelExecutorPass(
#else
const
BuildStrategy
&
strategy
)
{
#endif
details
::
ParallelExecutorPassManager
builder_factory
(
places
,
loss_var_name
,
param_names
,
local_scopes
,
strategy
);
if
(
use_cuda
)
{
#ifdef PADDLE_WITH_CUDA
builder_factory
.
SetNCCLContextMap
(
nccl_ctxs
);
#else
PADDLE_THROW
(
"Not compiled with CUDA."
);
#endif
}
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
main_program
));
if
(
!
strategy
.
debug_graphviz_path_
.
empty
())
{
auto
viz_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"graph_viz_pass"
);
...
...
@@ -62,8 +53,37 @@ std::unique_ptr<ir::Graph> ApplyParallelExecutorPass(
graph
=
viz_pass
->
Apply
(
std
::
move
(
graph
));
}
auto
builder
=
builder_factory
.
Create
();
graph
=
builder
->
Apply
(
std
::
move
(
graph
));
auto
multi_device_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"multi_device_pass"
);
multi_device_pass
->
SetNotOwned
<
const
std
::
vector
<
platform
::
Place
>>
(
"places"
,
&
places
);
multi_device_pass
->
SetNotOwned
<
const
std
::
string
>
(
"loss_var_name"
,
&
loss_var_name
);
multi_device_pass
->
SetNotOwned
<
const
std
::
unordered_set
<
std
::
string
>>
(
"params"
,
&
param_names
);
multi_device_pass
->
SetNotOwned
<
const
std
::
vector
<
Scope
*>>
(
"local_scopes"
,
&
local_scopes
);
multi_device_pass
->
SetNotOwned
<
const
BuildStrategy
>
(
"strategy"
,
&
strategy
);
#ifdef PADDLE_WITH_CUDA
platform
::
NCCLContextMap
*
nctx
=
use_cuda
?
nccl_ctxs
:
nullptr
;
multi_device_pass
->
SetNotOwned
<
platform
::
NCCLContextMap
>
(
"nccl_ctxs"
,
nctx
);
#endif
graph
=
multi_device_pass
->
Apply
(
std
::
move
(
graph
));
if
(
!
strategy
.
debug_graphviz_path_
.
empty
())
{
auto
multi_device_print_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"multi_device_print_pass"
);
multi_device_print_pass
->
SetNotOwned
<
const
std
::
string
>
(
"debug_graphviz_path"
,
&
strategy
.
debug_graphviz_path_
);
multi_device_print_pass
->
Set
<
details
::
GraphvizSSAGraphPrinter
>
(
"graph_printer"
,
new
details
::
GraphvizSSAGraphPrinter
);
graph
=
multi_device_print_pass
->
Apply
(
std
::
move
(
graph
));
}
auto
multi_device_check_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"multi_device_check_pass"
);
graph
=
multi_device_check_pass
->
Apply
(
std
::
move
(
graph
));
if
(
!
strategy
.
debug_graphviz_path_
.
empty
())
{
auto
viz_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"graph_viz_pass"
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录