Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
ea669796
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ea669796
编写于
1月 17, 2019
作者:
Q
Qiao Longfei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
can run
上级
afda8401
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
65 addition
and
13 deletion
+65
-13
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+1
-1
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+2
-0
paddle/fluid/framework/details/async_ssa_graph_executor.cc
paddle/fluid/framework/details/async_ssa_graph_executor.cc
+1
-0
paddle/fluid/framework/details/build_strategy.cc
paddle/fluid/framework/details/build_strategy.cc
+4
-1
paddle/fluid/framework/details/build_strategy.h
paddle/fluid/framework/details/build_strategy.h
+1
-0
paddle/fluid/framework/details/multi_devices_graph_pass.cc
paddle/fluid/framework/details/multi_devices_graph_pass.cc
+2
-0
paddle/fluid/framework/details/multi_devices_graph_pass.h
paddle/fluid/framework/details/multi_devices_graph_pass.h
+15
-1
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+36
-10
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+3
-0
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
ea669796
...
...
@@ -184,7 +184,7 @@ endif()
target_link_libraries
(
executor garbage_collector
)
cc_library
(
parallel_executor SRCS parallel_executor.cc DEPS
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor
async_ssa_graph_executor
graph build_strategy
fast_threaded_ssa_graph_executor variable_helper
)
...
...
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
ea669796
...
...
@@ -79,6 +79,8 @@ cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS
cc_library
(
parallel_ssa_graph_executor SRCS parallel_ssa_graph_executor.cc DEPS threaded_ssa_graph_executor
)
cc_library
(
async_ssa_graph_executor SRCS async_ssa_graph_executor.cc DEPS threaded_ssa_graph_executor
)
cc_test
(
broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
device_context broadcast_op_handle
)
cc_test
(
gather_op_test SRCS gather_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
...
...
paddle/fluid/framework/details/async_ssa_graph_executor.cc
浏览文件 @
ea669796
...
...
@@ -27,6 +27,7 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
pool_
(
places
.
size
()
>=
2
?
new
::
ThreadPool
(
places
.
size
())
:
nullptr
),
places_
(
std
::
move
(
places
)),
graphs_
(
std
::
move
(
graphs
))
{
VLOG
(
3
)
<<
"build AsyncSSAGraphExecutor"
;
PADDLE_ENFORCE_EQ
(
places_
.
size
(),
local_scopes_
.
size
());
// set the correct size of thread pool to each device.
...
...
paddle/fluid/framework/details/build_strategy.cc
浏览文件 @
ea669796
...
...
@@ -116,7 +116,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Convert graph to run on multi-devices.
void
AppendMultiDevPass
(
const
BuildStrategy
&
strategy
)
{
ir
::
Pass
*
multi_devices_pass
;
if
(
strategy_
.
is_distribution_
)
{
if
(
strategy_
.
async_mode_
)
{
multi_devices_pass
=
AppendPass
(
"async_multi_devices_pass"
).
get
();
}
else
if
(
strategy_
.
is_distribution_
)
{
multi_devices_pass
=
AppendPass
(
"dist_multi_devices_pass"
).
get
();
}
else
{
if
(
strategy
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kAllReduce
)
{
...
...
paddle/fluid/framework/details/build_strategy.h
浏览文件 @
ea669796
...
...
@@ -86,6 +86,7 @@ struct BuildStrategy {
// num_trainers is 1, so the current fields of build_strategy doesn't tell if
// it's distributed model.
bool
is_distribution_
{
false
};
bool
async_mode_
{
false
};
int
num_trainers_
{
1
};
int
trainer_id_
{
0
};
std
::
vector
<
std
::
string
>
trainers_endpoints_
;
...
...
paddle/fluid/framework/details/multi_devices_graph_pass.cc
浏览文件 @
ea669796
...
...
@@ -975,3 +975,5 @@ REGISTER_MULTI_DEVICES_PASS(
paddle
::
framework
::
details
::
AllReduceSSAGraphBuilder
);
REGISTER_MULTI_DEVICES_PASS
(
dist_multi_devices_pass
,
paddle
::
framework
::
details
::
DistSSAGraphBuilder
);
REGISTER_MULTI_DEVICES_PASS
(
async_multi_devices_pass
,
paddle
::
framework
::
details
::
AsyncSSAGraphBuilder
);
paddle/fluid/framework/details/multi_devices_graph_pass.h
浏览文件 @
ea669796
...
...
@@ -55,7 +55,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
bool
UseGPU
()
const
;
bool
NeedCollectiveOps
()
const
;
virtual
bool
NeedCollectiveOps
()
const
;
bool
IsScaleLossOp
(
ir
::
Node
*
node
)
const
;
...
...
@@ -116,6 +116,20 @@ class AllReduceSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
virtual
void
InsertPostprocessOps
(
ir
::
Graph
*
result
)
const
{}
};
class
AsyncSSAGraphBuilder
:
public
MultiDevSSAGraphBuilderBase
{
protected:
virtual
void
InsertCollectiveOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
p_name
,
const
std
::
string
&
g_name
)
const
{}
bool
NeedCollectiveOps
()
const
override
{
return
false
;
}
virtual
bool
DealWithSpecialOp
(
ir
::
Graph
*
result
,
ir
::
Node
*
node
)
const
{
return
false
;
}
virtual
void
InsertPostprocessOps
(
ir
::
Graph
*
result
)
const
{}
};
class
BalanceVarSSAGraphBuilder
:
public
MultiDevSSAGraphBuilderBase
{
protected:
int
GetVarDeviceID
(
const
std
::
string
&
varname
)
const
;
...
...
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
ea669796
...
...
@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/details/async_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
...
...
@@ -282,10 +283,19 @@ ParallelExecutor::ParallelExecutor(
graphs
.
push_back
(
std
::
move
(
graph
));
}
#else
std
::
unique_ptr
<
ir
::
Graph
>
graph
=
build_strategy
.
Apply
(
main_program
,
member_
->
places_
,
loss_var_name
,
member_
->
local_scopes_
,
member_
->
nranks_
,
member_
->
use_cuda_
);
graphs
.
push_back
(
std
::
move
(
graph
));
if
(
build_strategy
.
async_mode_
)
{
for
(
size_t
i
=
0
;
i
<
member_
->
places_
.
size
();
++
i
)
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
=
build_strategy
.
Apply
(
main_program
,
{
member_
->
places_
[
i
]},
loss_var_name
,
{
member_
->
local_scopes_
[
i
]},
member_
->
nranks_
,
member_
->
use_cuda_
);
graphs
.
push_back
(
std
::
move
(
graph
));
}
}
else
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
=
build_strategy
.
Apply
(
main_program
,
member_
->
places_
,
loss_var_name
,
member_
->
local_scopes_
,
member_
->
nranks_
,
member_
->
use_cuda_
);
graphs
.
push_back
(
std
::
move
(
graph
));
}
#endif
auto
max_memory_size
=
GetEagerDeletionThreshold
();
if
(
max_memory_size
>=
0
)
{
...
...
@@ -323,23 +333,31 @@ ParallelExecutor::ParallelExecutor(
"please don't pass loss_var_name."
;
}
}
if
(
build_strategy
.
enable_parallel_graph_
)
{
if
(
build_strategy
.
async_mode_
)
{
VLOG
(
3
)
<<
"use AsyncSSAGraphExecutor"
;
member_
->
executor_
.
reset
(
new
details
::
AsyncSSAGraphExecutor
(
exec_strategy
,
member_
->
local_scopes_
,
member_
->
places_
,
std
::
move
(
graphs
)));
}
else
if
(
build_strategy
.
enable_parallel_graph_
)
{
VLOG
(
3
)
<<
"use ParallelSSAGraphExecutor"
;
member_
->
executor_
.
reset
(
new
details
::
ParallelSSAGraphExecutor
(
exec_strategy
,
member_
->
local_scopes_
,
member_
->
places_
,
std
::
move
(
graphs
)));
}
else
{
if
(
exec_strategy
.
type_
==
ExecutionStrategy
::
kDefault
)
{
VLOG
(
3
)
<<
"use ThreadedSSAGraphExecutor"
;
member_
->
executor_
.
reset
(
new
details
::
ThreadedSSAGraphExecutor
(
exec_strategy
,
member_
->
local_scopes_
,
member_
->
places_
,
std
::
move
(
graphs
[
0
])));
}
else
{
VLOG
(
3
)
<<
"use FastThreadedSSAGraphExecutor"
;
member_
->
executor_
.
reset
(
new
details
::
FastThreadedSSAGraphExecutor
(
exec_strategy
,
member_
->
local_scopes_
,
member_
->
places_
,
std
::
move
(
graphs
[
0
])));
}
}
VLOG
(
3
)
<<
"use ScopeBufferedSSAGraphExecutor"
;
member_
->
executor_
.
reset
(
new
details
::
ScopeBufferedSSAGraphExecutor
(
exec_strategy
,
member_
->
local_scopes_
,
std
::
move
(
var_infos
),
member_
->
places_
,
std
::
move
(
member_
->
executor_
)));
...
...
@@ -401,14 +419,22 @@ void ParallelExecutor::BCastParamsToDevices(
auto
local_scope
=
member_
->
local_scopes_
[
i
];
auto
*
t
=
local_scope
->
Var
(
var
)
->
GetMutable
<
LoDTensor
>
();
// FIXME(zcd): LR_DECAY_COUNTER should not be shared. This is a hot fix.
if
(
member_
->
use_all_reduce_
||
member_
->
use_cuda_
||
var
==
"@LR_DECAY_COUNTER@"
)
{
auto
share_memory
=
[
&
]
{
t
->
Resize
(
dims
);
t
->
mutable_data
(
cpu
,
main_tensor
.
type
());
paddle
::
framework
::
TensorCopy
(
main_tensor
,
cpu
,
t
);
};
auto
copy_memory
=
[
&
]
{
t
->
ShareDataWith
(
main_tensor
);
};
// FIXME(zcd): LR_DECAY_COUNTER should not be shared. This is a hot fix.
if
(
member_
->
build_strategy_
.
async_mode_
)
{
share_memory
();
}
else
if
(
member_
->
use_all_reduce_
||
member_
->
use_cuda_
||
var
==
"@LR_DECAY_COUNTER@"
)
{
copy_memory
();
}
else
{
t
->
ShareDataWith
(
main_tensor
);
share_memory
(
);
}
}
}
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
ea669796
...
...
@@ -1030,6 +1030,9 @@ All parameter, weight, gradient are variables in Paddle.
"is_distribution"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
is_distribution_
;
},
[](
BuildStrategy
&
self
,
bool
b
)
{
self
.
is_distribution_
=
b
;
})
.
def_property
(
"async_mode"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
async_mode_
;
},
[](
BuildStrategy
&
self
,
bool
b
)
{
self
.
async_mode_
=
b
;
})
.
def_property
(
"memory_early_delete"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
memory_early_delete_
;
},
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录