Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
49313d40
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
49313d40
编写于
4月 03, 2018
作者:
X
Xin Pan
提交者:
GitHub
4月 03, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #9548 from panyx0718/group_nccl_all_reduce
Group nccl all reduce and improve performance (~14% for 4 device resnext)
上级
d2f9e193
cf251eb8
变更
11
显示空白变更内容
内联
并排
Showing
11 changed file
with
118 addition
and
35 deletion
+118
-35
paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc
paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc
+1
-1
paddle/fluid/framework/details/nccl_all_reduce_op_handle.h
paddle/fluid/framework/details/nccl_all_reduce_op_handle.h
+7
-0
paddle/fluid/framework/details/op_handle_base.h
paddle/fluid/framework/details/op_handle_base.h
+6
-0
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
...le/fluid/framework/details/threaded_ssa_graph_executor.cc
+50
-12
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
+13
-3
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+5
-3
paddle/fluid/framework/parallel_executor.h
paddle/fluid/framework/parallel_executor.h
+4
-2
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+2
-2
python/paddle/fluid/parallel_executor.py
python/paddle/fluid/parallel_executor.py
+13
-3
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+2
-0
python/paddle/fluid/tests/unittests/test_parallel_executor.py
...on/paddle/fluid/tests/unittests/test_parallel_executor.py
+15
-9
未找到文件。
paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc
浏览文件 @
49313d40
...
...
@@ -76,7 +76,7 @@ void NCCLAllReduceOpHandle::RunImpl() {
}
}
std
::
string
NCCLAllReduceOpHandle
::
Name
()
const
{
return
"
NCCL AllR
educe"
;
}
std
::
string
NCCLAllReduceOpHandle
::
Name
()
const
{
return
"
nccl_all_r
educe"
;
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/nccl_all_reduce_op_handle.h
浏览文件 @
49313d40
...
...
@@ -14,6 +14,9 @@
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
...
...
@@ -34,6 +37,10 @@ struct NCCLAllReduceOpHandle : public OpHandleBase {
std
::
string
Name
()
const
override
;
// Delay and buffer nccl_all_reduce together can significantly increase
// performance. Disable this feature by returning false.
bool
IsMultiDeviceTransfer
()
override
{
return
true
;
};
protected:
void
RunImpl
()
override
;
};
...
...
paddle/fluid/framework/details/op_handle_base.h
浏览文件 @
49313d40
...
...
@@ -13,6 +13,8 @@
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/platform/device_context.h"
...
...
@@ -53,6 +55,10 @@ class OpHandleBase {
void
AddOutput
(
VarHandleBase
*
out
);
// If the Op involves data transfer of multiple devices that
// will likely block other computations.
virtual
bool
IsMultiDeviceTransfer
()
{
return
false
;
}
protected:
virtual
void
RunImpl
()
=
0
;
};
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
浏览文件 @
49313d40
...
...
@@ -23,22 +23,36 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
size_t
num_threads
,
bool
use_event
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
std
::
unique_ptr
<
SSAGraph
>
&&
graph
)
std
::
unique_ptr
<
SSAGraph
>
&&
graph
,
bool
allow_op_delay
)
:
SSAGraphExecutor
(
std
::
move
(
graph
)),
pool_
(
num_threads
>=
2
?
new
::
ThreadPool
(
num_threads
)
:
nullptr
),
local_scopes_
(
local_scopes
),
places_
(
places
),
fetch_ctxs_
(
places
),
use_event_
(
use_event
)
{}
use_event_
(
use_event
),
running_ops_
(
0
),
allow_op_delay_
(
allow_op_delay
)
{}
void
ThreadedSSAGraphExecutor
::
RunDelayedOps
(
const
std
::
unordered_set
<
OpHandleBase
*>
&
delayed_ops
)
{
for
(
auto
op
:
delayed_ops
)
{
op
->
Run
(
use_event_
);
}
}
FeedFetchList
ThreadedSSAGraphExecutor
::
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
{
std
::
unordered_map
<
OpHandleBase
*
,
size_t
>
pending_ops
;
std
::
unordered_set
<
VarHandleBase
*>
pending_vars
;
BlockingQueue
<
VarHandleBase
*>
ready_vars
;
std
::
unordered_set
<
OpHandleBase
*>
ready_ops
;
// For ops (e.g. nccl_all_reduce) that need to coordinate multiple
// streams from multiple GPUs, it's faster to buffer them and schedule
// together since we currently cannot overlap computation and memcpy streams.
// Should revisit it if overlapping is available.
std
::
unordered_set
<
OpHandleBase
*>
delayed_ops
;
std
::
unordered_set
<
OpHandleBase
*>
blocked_by_delayed_ops
;
std
::
unordered_set
<
VarHandleBase
*>
delayed_vars
;
auto
InsertPendingVar
=
[
&
pending_vars
,
&
ready_vars
](
VarHandleBase
&
var
)
{
pending_vars
.
insert
(
&
var
);
...
...
@@ -106,7 +120,14 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
auto
run_all_ready_ops
=
[
&
]
{
for
(
auto
*
op
:
ready_ops
)
{
RunOp
(
ready_vars
,
op
);
if
(
op
->
IsMultiDeviceTransfer
()
&&
allow_op_delay_
)
{
delayed_ops
.
insert
(
op
);
delayed_vars
.
insert
(
op
->
outputs_
.
begin
(),
op
->
outputs_
.
end
());
ready_vars
.
Extend
(
op
->
outputs_
);
continue
;
}
running_ops_
++
;
RunOp
(
&
ready_vars
,
op
);
}
ready_ops
.
clear
();
};
...
...
@@ -118,13 +139,13 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
}
// Step 3. Execution
while
(
!
pending_vars
.
empty
())
{
while
(
!
pending_vars
.
empty
()
||
!
ready_ops
.
empty
()
||
!
delayed_ops
.
empty
()
)
{
// 1. Run All Ready ops
run_all_ready_ops
();
// 2. Find ready variable
bool
timeout
;
auto
cur_ready_vars
=
ready_vars
.
PopAll
(
1
000
,
&
timeout
);
auto
cur_ready_vars
=
ready_vars
.
PopAll
(
1
,
&
timeout
);
if
(
timeout
)
{
if
(
exception_
)
{
...
...
@@ -141,13 +162,29 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
auto
&
deps
=
pending_ops
[
op
];
--
deps
;
if
(
deps
==
0
)
{
if
(
delayed_vars
.
find
(
ready_var
)
!=
delayed_vars
.
end
())
{
blocked_by_delayed_ops
.
insert
(
op
);
}
else
{
ready_ops
.
insert
(
op
);
}
}
}
}
// When there are no other ops to schedule, schedule buffered delayed
// ops and unblock other ops.
if
(
ready_ops
.
empty
()
&&
!
delayed_ops
.
empty
()
&&
running_ops_
==
0
)
{
RunDelayedOps
(
delayed_ops
);
delayed_ops
.
clear
();
for
(
auto
*
op
:
blocked_by_delayed_ops
)
{
ready_ops
.
insert
(
op
);
}
blocked_by_delayed_ops
.
clear
();
}
// Keep loop until all vars are ready.
}
PADDLE_ENFORCE
(
ready_ops
.
empty
());
PADDLE_ENFORCE
(
delayed_ops
.
empty
());
PADDLE_ENFORCE
(
blocked_by_delayed_ops
.
empty
());
++
computation_count_
;
auto
sync_computation
=
[
&
]
{
...
...
@@ -182,12 +219,13 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
}
void
ThreadedSSAGraphExecutor
::
RunOp
(
BlockingQueue
<
VarHandleBase
*>
&
ready_var_q
,
details
::
OpHandleBase
*
op
)
{
auto
op_run
=
[
&
ready_var_q
,
op
,
this
]
{
BlockingQueue
<
VarHandleBase
*>
*
ready_var_q
,
details
::
OpHandleBase
*
op
)
{
auto
op_run
=
[
ready_var_q
,
op
,
this
]
{
try
{
VLOG
(
10
)
<<
op
->
Name
()
<<
" : "
<<
op
->
DebugString
();
op
->
Run
(
use_event_
);
ready_var_q
.
Extend
(
op
->
outputs_
);
running_ops_
--
;
ready_var_q
->
Extend
(
op
->
outputs_
);
}
catch
(
platform
::
EnforceNotMet
ex
)
{
exception_
.
reset
(
new
platform
::
EnforceNotMet
(
ex
));
}
catch
(...)
{
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
浏览文件 @
49313d40
...
...
@@ -14,7 +14,12 @@
#pragma once
#include <chrono>
#include <deque>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include <functional>
#include "ThreadPool.h" // ThreadPool in thrird party
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
...
...
@@ -70,7 +75,8 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
ThreadedSSAGraphExecutor
(
size_t
num_threads
,
bool
use_event
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
std
::
unique_ptr
<
SSAGraph
>
&&
graph
);
std
::
unique_ptr
<
SSAGraph
>
&&
graph
,
bool
allow_op_delay
);
// Run a SSAGraph by a thread pool
// Use topological sort algorithm
...
...
@@ -79,9 +85,11 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
~
ThreadedSSAGraphExecutor
()
{}
private:
void
RunOp
(
BlockingQueue
<
VarHandleBase
*>
&
ready_var_q
,
void
RunOp
(
BlockingQueue
<
VarHandleBase
*>
*
ready_var_q
,
details
::
OpHandleBase
*
op
);
void
RunDelayedOps
(
const
std
::
unordered_set
<
OpHandleBase
*>
&
delayed_ops
);
private:
std
::
unique_ptr
<::
ThreadPool
>
pool_
;
std
::
vector
<
Scope
*>
local_scopes_
;
...
...
@@ -89,6 +97,8 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
platform
::
DeviceContextPool
fetch_ctxs_
;
const
bool
use_event_
;
std
::
unique_ptr
<
platform
::
EnforceNotMet
>
exception_
;
std
::
atomic
<
int
>
running_ops_
;
bool
allow_op_delay_
;
size_t
computation_count_
{
0
};
size_t
max_async_computation
{
100
};
...
...
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
49313d40
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/parallel_executor.h"
#include "paddle/fluid/platform/profiler.h"
#include <string>
#include <vector>
...
...
@@ -47,7 +48,7 @@ ParallelExecutor::ParallelExecutor(
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
unordered_set
<
std
::
string
>
&
params
,
const
ProgramDesc
&
startup_program
,
const
ProgramDesc
&
main_program
,
const
std
::
string
&
loss_var_name
,
Scope
*
scope
)
const
std
::
string
&
loss_var_name
,
Scope
*
scope
,
bool
allow_op_delay
)
:
member_
(
new
ParallelExecutorPrivate
(
places
))
{
member_
->
global_scope_
=
scope
;
...
...
@@ -82,8 +83,8 @@ ParallelExecutor::ParallelExecutor(
auto
graph
=
builder
.
Build
(
main_program
);
member_
->
executor_
.
reset
(
new
details
::
ThreadedSSAGraphExecutor
(
num_threads
,
use_event
,
member_
->
local_scopes_
,
places
,
std
::
move
(
graph
)
));
num_threads
,
use_event
,
member_
->
local_scopes_
,
places
,
std
::
move
(
graph
),
allow_op_delay
));
// Step 3. Create vars in each scope;
for
(
auto
*
scope
:
member_
->
local_scopes_
)
{
...
...
@@ -151,6 +152,7 @@ void ParallelExecutor::BCastParamsToGPUs(
void
ParallelExecutor
::
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
,
const
std
::
string
&
fetched_var_name
)
{
platform
::
RecordBlock
b
(
0
);
auto
fetch_data
=
member_
->
executor_
->
Run
(
fetch_tensors
);
*
member_
->
global_scope_
->
Var
(
fetched_var_name
)
->
GetMutable
<
FeedFetchList
>
()
=
fetch_data
;
...
...
paddle/fluid/framework/parallel_executor.h
浏览文件 @
49313d40
...
...
@@ -14,8 +14,9 @@ limitations under the License. */
#pragma once
#include <
future
>
#include <
string
>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/program_desc.h"
...
...
@@ -37,7 +38,8 @@ class ParallelExecutor {
const
std
::
unordered_set
<
std
::
string
>&
params
,
const
ProgramDesc
&
startup_program
,
const
ProgramDesc
&
main_program
,
const
std
::
string
&
loss_var_name
,
Scope
*
scope
);
const
std
::
string
&
loss_var_name
,
Scope
*
scope
,
bool
allow_op_delay
);
void
Run
(
const
std
::
vector
<
std
::
string
>&
fetch_tensors
,
const
std
::
string
&
fetched_var_name
=
"fetched_var"
);
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
49313d40
...
...
@@ -504,10 +504,10 @@ All parameter, weight, gradient are variables in Paddle.
const
std
::
unordered_set
<
std
::
string
>
&
params
,
const
ProgramDesc
&
startup_program
,
const
ProgramDesc
&
main_program
,
const
std
::
string
&
loss_var_name
,
Scope
*
scope
)
{
Scope
*
scope
,
bool
allow_op_delay
)
{
new
(
&
self
)
ParallelExecutor
(
num_threads
,
use_event
,
places
,
params
,
startup_program
,
main_program
,
loss_var_name
,
scope
);
loss_var_name
,
scope
,
allow_op_delay
);
})
.
def
(
"run"
,
&
ParallelExecutor
::
Run
);
...
...
python/paddle/fluid/parallel_executor.py
浏览文件 @
49313d40
...
...
@@ -21,7 +21,11 @@ __all__ = ['ParallelExecutor']
class
ParallelExecutor
(
object
):
def
__init__
(
self
,
loss_name
,
use_cuda
,
num_threads
=
None
):
def
__init__
(
self
,
loss_name
,
use_cuda
,
num_threads
=
None
,
allow_op_delay
=
False
):
places
=
[]
if
use_cuda
:
for
i
in
xrange
(
core
.
get_cuda_device_count
()):
...
...
@@ -35,7 +39,12 @@ class ParallelExecutor(object):
places
.
append
(
p
)
if
num_threads
is
None
:
num_threads
=
min
(
len
(
places
)
*
2
,
multiprocessing
.
cpu_count
())
if
use_cuda
:
# Experiments on se-resnext shows that too many threads hurt
# performance. Worth tunning for other models in the future.
num_threads
=
len
(
places
)
else
:
min
(
len
(
places
)
*
2
,
multiprocessing
.
cpu_count
())
startup
=
framework
.
default_startup_program
()
main
=
framework
.
default_main_program
()
...
...
@@ -52,7 +61,8 @@ class ParallelExecutor(object):
startup
.
desc
,
main
.
desc
,
loss_name
,
scope
)
scope
,
allow_op_delay
)
self
.
scope
=
scope
def
run
(
self
,
fetch_list
):
...
...
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
49313d40
...
...
@@ -29,6 +29,7 @@ function(py_test_modules TARGET_NAME)
endfunction
()
# test time consuming OPs in a separate process for expliot parallism
list
(
REMOVE_ITEM TEST_OPS test_parallel_executor
)
list
(
REMOVE_ITEM TEST_OPS test_warpctc_op
)
list
(
REMOVE_ITEM TEST_OPS test_dyn_rnn
)
list
(
REMOVE_ITEM TEST_OPS test_mul_op
)
...
...
@@ -64,6 +65,7 @@ else()
endif
(
WITH_FAST_BUNDLE_TEST
)
# tests with high overhead
py_test_modules
(
test_parallel_executor MODULES test_parallel_executor
)
py_test_modules
(
test_warpctc_op MODULES test_warpctc_op ENVS FLAGS_warpctc_dir=
${
WARPCTC_LIB_DIR
}
)
py_test_modules
(
test_train_dyn_rnn MODULES test_dyn_rnn
)
py_test_modules
(
test_mul_op MODULES test_mul_op
)
...
...
python/paddle/fluid/tests/unittests/test_parallel_executor.py
浏览文件 @
49313d40
...
...
@@ -135,18 +135,18 @@ def bottleneck_block(input, num_filters, stride, cardinality, reduction_ratio):
return
fluid
.
layers
.
elementwise_add
(
x
=
short
,
y
=
scale
,
act
=
'relu'
)
def
SE_ResNeXt152
(
batch_size
=
4
):
def
SE_ResNeXt152
Small
(
batch_size
=
2
):
img
=
fluid
.
layers
.
fill_constant
(
shape
=
[
batch_size
,
3
,
224
,
224
],
dtype
=
'float32'
,
value
=
0.0
)
label
=
fluid
.
layers
.
fill_constant
(
shape
=
[
batch_size
,
1
],
dtype
=
'int64'
,
value
=
0.0
)
conv
=
conv_bn_layer
(
input
=
img
,
num_filters
=
64
,
filter_size
=
3
,
stride
=
2
,
act
=
'relu'
)
input
=
img
,
num_filters
=
16
,
filter_size
=
3
,
stride
=
2
,
act
=
'relu'
)
conv
=
conv_bn_layer
(
input
=
conv
,
num_filters
=
64
,
filter_size
=
3
,
stride
=
1
,
act
=
'relu'
)
input
=
conv
,
num_filters
=
16
,
filter_size
=
3
,
stride
=
1
,
act
=
'relu'
)
conv
=
conv_bn_layer
(
input
=
conv
,
num_filters
=
1
28
,
filter_size
=
3
,
stride
=
1
,
act
=
'relu'
)
input
=
conv
,
num_filters
=
1
6
,
filter_size
=
3
,
stride
=
1
,
act
=
'relu'
)
conv
=
fluid
.
layers
.
pool2d
(
input
=
conv
,
pool_size
=
3
,
pool_stride
=
2
,
pool_padding
=
1
,
pool_type
=
'max'
)
...
...
@@ -184,7 +184,8 @@ class TestParallelExecutorBase(unittest.TestCase):
method
,
memory_opt
=
True
,
iter
=
10
,
batch_size
=
None
):
batch_size
=
None
,
allow_op_delay
=
False
):
main
=
fluid
.
Program
()
startup
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main
,
startup
):
...
...
@@ -194,7 +195,10 @@ class TestParallelExecutorBase(unittest.TestCase):
if
memory_opt
:
fluid
.
memory_optimize
(
main
)
exe
=
fluid
.
ParallelExecutor
(
loss_name
=
loss
.
name
,
use_cuda
=
True
)
exe
=
fluid
.
ParallelExecutor
(
loss_name
=
loss
.
name
,
use_cuda
=
True
,
allow_op_delay
=
allow_op_delay
)
if
batch_size
is
not
None
:
batch_size
*=
fluid
.
core
.
get_cuda_device_count
()
begin
=
time
.
time
()
...
...
@@ -222,7 +226,7 @@ class TestMNIST(TestParallelExecutorBase):
def
setUpClass
(
cls
):
# Convert mnist to recordio file
with
fluid
.
program_guard
(
fluid
.
Program
(),
fluid
.
Program
()):
reader
=
paddle
.
batch
(
mnist
.
train
(),
batch_size
=
32
)
reader
=
paddle
.
batch
(
mnist
.
train
(),
batch_size
=
4
)
feeder
=
fluid
.
DataFeeder
(
feed_list
=
[
# order is image and label
fluid
.
layers
.
data
(
...
...
@@ -236,9 +240,11 @@ class TestMNIST(TestParallelExecutorBase):
def
test_simple_fc
(
self
):
self
.
check_network_convergence
(
simple_fc_net
)
self
.
check_network_convergence
(
simple_fc_net
,
allow_op_delay
=
True
)
def
test_batchnorm_fc
(
self
):
self
.
check_network_convergence
(
fc_with_batchnorm
)
self
.
check_network_convergence
(
fc_with_batchnorm
,
allow_op_delay
=
True
)
class
TestResnet
(
TestParallelExecutorBase
):
...
...
@@ -262,10 +268,10 @@ class TestResnet(TestParallelExecutorBase):
def
test_resnet
(
self
):
import
functools
batch_size
=
4
batch_size
=
2
self
.
check_network_convergence
(
functools
.
partial
(
SE_ResNeXt152
,
batch_size
=
batch_size
),
SE_ResNeXt152
Small
,
batch_size
=
batch_size
),
iter
=
20
,
batch_size
=
batch_size
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录