Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
7b73fc9e
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
7b73fc9e
编写于
1月 09, 2019
作者:
X
Xin Pan
提交者:
GitHub
1月 09, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #15089 from panyx0718/api
try unify Executor and ParallelExecutor
上级
223cc89f
c4b09a71
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
388 addition
and
81 deletion
+388
-81
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+4
-4
paddle/fluid/framework/parallel_executor.h
paddle/fluid/framework/parallel_executor.h
+1
-2
paddle/fluid/imperative/layer.h
paddle/fluid/imperative/layer.h
+1
-0
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+1
-2
python/paddle/fluid/compiler.py
python/paddle/fluid/compiler.py
+204
-0
python/paddle/fluid/executor.py
python/paddle/fluid/executor.py
+137
-24
python/paddle/fluid/parallel_executor.py
python/paddle/fluid/parallel_executor.py
+3
-4
python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
...ddle/fluid/tests/unittests/parallel_executor_test_base.py
+13
-20
python/paddle/fluid/tests/unittests/test_dist_base.py
python/paddle/fluid/tests/unittests/test_dist_base.py
+11
-12
python/paddle/fluid/tests/unittests/test_parallel_executor_test_while_train.py
...ests/unittests/test_parallel_executor_test_while_train.py
+13
-13
未找到文件。
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
7b73fc9e
...
...
@@ -193,15 +193,14 @@ ParallelExecutor::ParallelExecutor(
const
std
::
unordered_set
<
std
::
string
>
&
bcast_vars
,
const
ProgramDesc
&
main_program
,
const
std
::
string
&
loss_var_name
,
Scope
*
scope
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
ExecutionStrategy
&
exec_strategy
,
const
BuildStrategy
&
build_strategy
,
size_t
num_trainers
,
size_t
trainer_id
)
const
ExecutionStrategy
&
exec_strategy
,
const
BuildStrategy
&
build_strategy
)
:
member_
(
new
ParallelExecutorPrivate
(
places
))
{
member_
->
global_scope_
=
scope
;
member_
->
use_cuda_
=
exec_strategy
.
use_cuda_
;
member_
->
build_strategy_
=
build_strategy
;
member_
->
use_all_reduce_
=
build_strategy
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kAllReduce
;
member_
->
nranks_
=
num_trainers
*
places
.
size
();
member_
->
nranks_
=
build_strategy
.
num_trainers_
*
places
.
size
();
if
(
!
member_
->
use_all_reduce_
)
{
PADDLE_ENFORCE
(
places
.
size
()
>
1
,
...
...
@@ -253,7 +252,8 @@ ParallelExecutor::ParallelExecutor(
}
member_
->
nccl_ctxs_
.
reset
(
new
platform
::
NCCLContextMap
(
member_
->
places_
,
nccl_id
,
num_trainers
,
trainer_id
));
member_
->
places_
,
nccl_id
,
build_strategy
.
num_trainers_
,
build_strategy
.
trainer_id_
));
#else
PADDLE_THROW
(
"Not compiled with CUDA"
);
#endif
...
...
paddle/fluid/framework/parallel_executor.h
浏览文件 @
7b73fc9e
...
...
@@ -50,8 +50,7 @@ class ParallelExecutor {
const
std
::
string
&
loss_var_name
,
Scope
*
scope
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
ExecutionStrategy
&
exec_strategy
,
const
BuildStrategy
&
build_strategy
,
size_t
num_trainers
=
1
,
size_t
trainer_id
=
0
);
const
BuildStrategy
&
build_strategy
);
~
ParallelExecutor
();
...
...
paddle/fluid/imperative/layer.h
浏览文件 @
7b73fc9e
...
...
@@ -77,6 +77,7 @@ class PreparedOp {
framework
::
OperatorWithKernel
::
OpKernelFunc
func
;
platform
::
DeviceContext
*
dev_ctx
;
};
class
OpBase
;
class
VarBase
{
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
7b73fc9e
...
...
@@ -1019,8 +1019,7 @@ All parameter, weight, gradient are variables in Paddle.
pe
.
def
(
py
::
init
<
const
std
::
vector
<
platform
::
Place
>
&
,
const
std
::
unordered_set
<
std
::
string
>
&
,
const
ProgramDesc
&
,
const
std
::
string
&
,
Scope
*
,
std
::
vector
<
Scope
*>
&
,
const
ExecutionStrategy
&
,
const
BuildStrategy
&
,
size_t
,
size_t
>
())
const
ExecutionStrategy
&
,
const
BuildStrategy
&>
())
// NOTE: even we return a vec<Scope*>* to Python use reference policy.
// We still cannot get local_scope from this vector, since the element
// of vec<Scope*> will be freed by Python GC. We can only return Scope*
...
...
python/paddle/fluid/compiler.py
0 → 100644
浏览文件 @
7b73fc9e
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
multiprocessing
import
os
import
six
import
sys
from
..
import
compat
as
cpt
from
.
import
core
ExecutionStrategy
=
core
.
ParallelExecutor
.
ExecutionStrategy
BuildStrategy
=
core
.
ParallelExecutor
.
BuildStrategy
def
_place_obj
(
place
):
p
=
core
.
Place
()
p
.
set_place
(
place
)
return
p
class
CompiledProgram
(
object
):
"""
Compiles a Program for execution.
1. Users first create the program with layers.
2. Optionally, users use CompiledProgram to optimize the program before run.
3. The original program or CompiledProgram is run by executor.
The CompiledProgram is used to transform a program for various
optimizations, for example.
* Pre-compute some logic once so that each run is faster.
* Transform the program so that it can run in multiple devices.
* TODO: transform the program for optimized inference or distributed
training.
Example:
.. code-block:: python
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup)
compiled_prog = compiler.CompiledProgram(main).with_data_parallel(
loss_name=loss.name)
for i in range(5):
test_loss, = exe.run(compiled_prog,
feed=feed_dict,
fetch_list=[loss.name])
Args:
program: Program instance that contains the model logic.
"""
def
__init__
(
self
,
program
):
self
.
_program
=
program
self
.
_scope
=
None
self
.
_place
=
None
self
.
_executor
=
None
self
.
_compiled
=
False
self
.
_is_data_parallel
=
False
def
with_data_parallel
(
self
,
loss_name
=
None
,
build_strategy
=
None
,
exec_strategy
=
None
,
share_vars_from
=
None
):
"""Configs the program to run in data parallel way.
Args:
loss_name (str): The loss name must set in training. Default None.
build_strategy(BuildStrategy): build_strategy is used to
build the graph so it can run on multiple devices/cores with
optimized topology.
For more information, please refer to fluid.BuildStrategy.
Default None.
exec_strategy(ExecutionStrategy): exec_strategy is used to
to select the a way to execute the graph, for example how many
threads are used, how many iterations to clean up the temp
variables. For more information, please refer
to fluid.ExecutionStrategy. Default None.
share_vars_from(CompiledProgram): If provide, this CompiledProgram
will share variables from `share_vars_from`. `share_vars_from`
must be run by the executor before this CompiledProgram so that
vars are ready.
Returns:
self
"""
assert
not
self
.
_is_data_parallel
,
"Already compiled with parallel."
self
.
_is_data_parallel
=
True
self
.
_build_strategy
=
build_strategy
self
.
_exec_strategy
=
exec_strategy
self
.
_loss_name
=
loss_name
self
.
_share_vars_from
=
share_vars_from
if
self
.
_exec_strategy
is
None
:
self
.
_exec_strategy
=
ExecutionStrategy
()
if
self
.
_build_strategy
is
None
:
self
.
_build_strategy
=
BuildStrategy
()
return
self
def
_with_distributed
(
self
):
raise
NotImplementedError
()
def
_with_inference_optimize
(
self
):
raise
NotImplementedError
()
def
_compile_data_parallel
(
self
):
if
self
.
_share_vars_from
:
if
self
.
_scope
:
sys
.
stderr
.
write
(
"share_vars_from is set, scope is ignored.
\n
"
)
if
not
self
.
_share_vars_from
.
_is_data_parallel
:
raise
ValueError
(
"share_vars_from is not data parallel. Cannot "
"share vars from it."
)
if
self
.
_share_vars_from
.
_executor
is
None
:
raise
ValueError
(
"share_vars_from is not compiled and run, so there is no "
"var to share."
)
self
.
_local_scopes
=
self
.
_share_vars_from
.
_executor
.
local_scopes
()
else
:
self
.
_local_scopes
=
[]
self
.
_exec_strategy
.
use_cuda
=
isinstance
(
self
.
_place
,
core
.
CUDAPlace
)
if
self
.
_exec_strategy
.
use_cuda
:
gpus_env
=
os
.
getenv
(
"FLAGS_selected_gpus"
)
if
gpus_env
:
gpus
=
[
int
(
s
)
for
s
in
gpus_env
.
split
(
","
)]
else
:
gpus
=
[
i
for
i
in
six
.
moves
.
range
(
core
.
get_cuda_device_count
())
]
self
.
_places
=
[
core
.
CUDAPlace
(
i
)
for
i
in
gpus
]
else
:
cpu_num
=
int
(
os
.
environ
.
get
(
'CPU_NUM'
,
multiprocessing
.
cpu_count
()))
self
.
_places
=
[
core
.
CPUPlace
()
for
_
in
six
.
moves
.
range
(
cpu_num
)]
assert
self
.
_places
,
"no place for execution"
if
self
.
_exec_strategy
.
num_threads
==
0
:
if
self
.
_exec_strategy
.
use_cuda
:
# Experiments on se-resnext shows that too many threads hurt
# performance. Worth tunning for other models in the future.
self
.
_exec_strategy
.
num_threads
=
len
(
self
.
_places
)
*
4
else
:
cpu_num
=
int
(
os
.
environ
.
get
(
'CPU_NUM'
,
multiprocessing
.
cpu_count
()))
self
.
_exec_strategy
.
num_threads
=
cpu_num
*
2
trainers_endpoints
=
self
.
_program
.
_trainers_endpoints
if
self
.
_build_strategy
.
num_trainers
>
1
and
trainers_endpoints
:
assert
self
.
_build_strategy
.
num_trainers
==
len
(
trainers_endpoints
),
"num_trainers == len(end_points)"
self
.
_build_strategy
.
trainers_endpoints
=
trainers_endpoints
self
.
_persistable_vars
=
set
([
cpt
.
to_text
(
v
.
name
)
for
v
in
[
var
for
var
in
self
.
_program
.
list_vars
()
if
var
.
persistable
and
var
.
type
!=
core
.
VarDesc
.
VarType
.
RAW
]
])
places
=
list
(
map
(
_place_obj
,
self
.
_places
))
return
core
.
ParallelExecutor
(
places
,
self
.
_persistable_vars
,
self
.
_program
.
desc
,
cpt
.
to_text
(
self
.
_loss_name
)
if
self
.
_loss_name
else
six
.
u
(
''
),
self
.
_scope
,
self
.
_local_scopes
,
self
.
_exec_strategy
,
self
.
_build_strategy
)
def
_compile
(
self
,
scope
,
place
):
"""Compile the program based on the configs.
Args:
scope: The variables (resources) that are associated with
this compiled program.
place: The location that the compiled program will be run on.
Returns:
self
"""
if
self
.
_compiled
:
if
scope
and
self
.
_scope
!=
scope
:
raise
ValueError
(
"Cannot compile with different scope"
)
if
place
and
self
.
_place
!=
place
:
raise
ValueError
(
"Cannot compile with different place"
)
return
self
self
.
_compiled
=
True
self
.
_scope
=
scope
self
.
_place
=
place
if
self
.
_is_data_parallel
:
self
.
_executor
=
self
.
_compile_data_parallel
()
else
:
p
=
_place_obj
(
self
.
_place
)
self
.
_executor
=
core
.
Executor
(
p
)
return
self
python/paddle/fluid/executor.py
浏览文件 @
7b73fc9e
...
...
@@ -14,11 +14,15 @@
from
__future__
import
print_function
import
os
import
multiprocessing
import
numpy
as
np
import
contextlib
import
six
from
.framework
import
Program
,
default_main_program
,
Variable
from
.
import
core
from
.
import
compiler
from
..
import
compat
as
cpt
__all__
=
[
'Executor'
,
'global_scope'
,
'scope_guard'
]
...
...
@@ -204,20 +208,20 @@ def _fetch_var(name, scope=None, return_numpy=True):
return
tensor
def
_get_program_cache_key
(
feed
,
fetch_list
):
feed_var_names
=
list
(
feed
.
keys
())
def
_to_name_str
(
var
):
if
isinstance
(
var
,
Variable
):
return
var
.
desc
.
name
()
elif
isinstance
(
var
,
str
):
return
var
elif
isinstance
(
var
,
six
.
string_types
):
return
str
(
var
)
else
:
raise
TypeError
(
str
(
var
)
+
" should be Variable or str"
)
def
to_name_str
(
var
):
if
isinstance
(
var
,
Variable
):
return
var
.
desc
.
name
()
elif
isinstance
(
var
,
str
):
return
var
elif
isinstance
(
var
,
six
.
string_types
):
return
str
(
var
)
else
:
raise
TypeError
(
str
(
var
)
+
" should be Variable or str"
)
fetch_var_names
=
list
(
map
(
to_name_str
,
fetch_list
))
def
_get_program_cache_key
(
feed
,
fetch_list
):
feed_var_names
=
list
(
feed
.
keys
())
fetch_var_names
=
list
(
map
(
_to_name_str
,
fetch_list
))
return
str
(
feed_var_names
+
fetch_var_names
)
...
...
@@ -266,6 +270,29 @@ class Executor(object):
But the global scope variables will be persistent through different runs.
All of ops in program will be running in sequence.
Example:
.. code-block:: python
# First create the Executor.
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
# Run the startup program once and only once.
# Not need to optimize/compile the startup program.
exe.run(fluid.default_startup_program())
# Run the main program directly without compile.
loss, = exe.run(fluid.default_main_program(),
feed=feed_dict,
fetch_list=[loss.name])
# Or, compiled the program and run. See `CompiledProgram` for more detail.
compiled_prog = compiler.CompiledProgram(
fluid.default_main_program()).with_data_parallel(
loss_name=loss.name)
loss, = exe.run(compiled_prog,
feed=feed_dict,
fetch_list=[loss.name])
Args:
place(core.CPUPlace|core.CUDAPlace(n)): indicate the executor run on which device
...
...
@@ -275,11 +302,8 @@ class Executor(object):
def
__init__
(
self
,
place
):
self
.
place
=
place
p
=
core
.
Place
()
p
.
set_place
(
place
)
self
.
executor
=
core
.
Executor
(
p
)
self
.
program_caches
=
dict
()
self
.
executor
=
None
self
.
_closed
=
False
def
_get_program_cache
(
self
,
program_cache_key
):
...
...
@@ -361,6 +385,7 @@ class Executor(object):
You can no long use this executor after calling this method.
For the distributed training, this method would free the resource on PServers related to
the current Trainer.
TODO(panyx0718): Why ParallelExecutor doesn't have close?
Example:
>>> cpu = core.CPUPlace()
...
...
@@ -368,10 +393,55 @@ class Executor(object):
>>> ...
>>> exe.close()
"""
if
not
self
.
_closed
:
if
not
self
.
_closed
and
self
.
executor
:
self
.
executor
.
close
()
self
.
_closed
=
True
def
_run_parallel
(
self
,
scope
,
feed
,
fetch_list
,
fetch_var_name
,
return_numpy
):
if
isinstance
(
feed
,
dict
):
feed_tensor_dict
=
dict
()
for
feed_name
in
feed
:
feed_tensor
=
feed
[
feed_name
]
if
not
isinstance
(
feed_tensor
,
core
.
LoDTensor
):
feed_tensor
=
core
.
LoDTensor
()
# always set to CPU place, since the tensor need to be splitted
# it is fast in CPU
feed_tensor
.
set
(
feed
[
feed_name
],
core
.
CPUPlace
())
feed_tensor_dict
[
feed_name
]
=
feed_tensor
self
.
executor
.
feed_and_split_tensor_into_local_scopes
(
feed_tensor_dict
)
elif
isinstance
(
feed
,
list
)
or
isinstance
(
feed
,
tuple
):
if
len
(
feed
)
!=
len
(
self
.
_places
):
raise
ValueError
(
"Feed a list of tensor, the list should be the same size as places"
)
res
=
list
()
for
i
,
each
in
enumerate
(
feed
):
if
not
isinstance
(
each
,
dict
):
raise
TypeError
(
"Each element of feed list should be a dict"
)
res_dict
=
dict
()
for
feed_name
in
each
:
tensor
=
each
[
feed_name
]
if
not
isinstance
(
tensor
,
core
.
LoDTensor
):
tmp
=
core
.
LoDTensor
()
tmp
.
set
(
tensor
,
self
.
_places
[
i
])
tensor
=
tmp
res_dict
[
feed_name
]
=
tensor
res
.
append
(
res_dict
)
self
.
executor
.
feed_tensors_into_local_scopes
(
res
)
fetch_var_names
=
list
(
map
(
_to_name_str
,
fetch_list
))
self
.
executor
.
run
(
fetch_var_names
,
fetch_var_name
)
arr
=
scope
.
find_var
(
fetch_var_name
).
get_lod_tensor_array
()
if
return_numpy
:
return
as_numpy
(
arr
)
return
[
arr
[
i
]
for
i
in
range
(
len
(
arr
))]
def
run
(
self
,
program
=
None
,
feed
=
None
,
...
...
@@ -391,8 +461,9 @@ class Executor(object):
operators in the program but not only the operators dependent by the fetch_list
Args:
program(Program): the program that need to run, if not provied, then default_main_program will be used.
feed(dict): feed variable map, e.g. {"image": ImageData, "label": LableData}
program(Program|CompiledProgram): the program that need to run,
if not provided, then default_main_program will be used.
feed(dict): feed variable map, e.g. {"image": ImageData, "label": LabelData}
fetch_list(list): a list of variable or variable names that user want to get, run will return them according to this list.
feed_var_name(str): the name for the input variable of feed Operator.
fetch_var_name(str): the name for the output variable of fetch Operator.
...
...
@@ -428,14 +499,59 @@ class Executor(object):
if
self
.
_closed
:
raise
RuntimeError
(
"Attempted to use a closed Executor"
)
if
scope
is
None
:
scope
=
global_scope
()
if
fetch_list
is
None
:
fetch_list
=
[]
compiled
=
isinstance
(
program
,
compiler
.
CompiledProgram
)
# For backward compatibility, run directly.
if
not
compiled
:
if
not
self
.
executor
:
p
=
core
.
Place
()
p
.
set_place
(
self
.
place
)
self
.
executor
=
core
.
Executor
(
p
)
return
self
.
_run
(
program
,
feed
=
feed
,
fetch_list
=
fetch_list
,
feed_var_name
=
feed_var_name
,
fetch_var_name
=
fetch_var_name
,
scope
=
scope
,
return_numpy
=
return_numpy
,
use_program_cache
=
use_program_cache
)
program
.
_compile
(
scope
,
self
.
place
)
self
.
executor
=
program
.
_executor
if
program
.
_is_data_parallel
:
return
self
.
_run_parallel
(
scope
=
scope
,
feed
=
feed
,
fetch_list
=
fetch_list
,
fetch_var_name
=
fetch_var_name
,
return_numpy
=
return_numpy
)
else
:
# TODO(panyx0718): Can compile program to optimize executor
# performance.
return
self
.
_run
(
program
.
_program
,
feed
=
feed
,
fetch_list
=
fetch_list
,
feed_var_name
=
feed_var_name
,
fetch_var_name
=
fetch_var_name
,
scope
=
scope
,
return_numpy
=
return_numpy
,
use_program_cache
=
use_program_cache
)
def
_run
(
self
,
program
,
feed
,
fetch_list
,
feed_var_name
,
fetch_var_name
,
scope
,
return_numpy
,
use_program_cache
):
if
feed
is
None
:
feed
=
{}
if
not
isinstance
(
feed
,
dict
):
raise
TypeError
(
"feed requires dict as its Parameter. But you passed in %s"
%
(
type
(
feed
)))
if
fetch_list
is
None
:
fetch_list
=
[]
if
program
is
None
:
program
=
default_main_program
()
...
...
@@ -444,9 +560,6 @@ class Executor(object):
"Executor requires Program as its Parameter. But you passed in %s"
%
(
type
(
program
)))
if
scope
is
None
:
scope
=
global_scope
()
cache_key
=
_get_program_cache_key
(
feed
,
fetch_list
)
if
use_program_cache
:
cached_program
=
self
.
_get_program_cache
(
cache_key
)
...
...
python/paddle/fluid/parallel_executor.py
浏览文件 @
7b73fc9e
...
...
@@ -181,9 +181,8 @@ class ParallelExecutor(object):
# step7: init ParallelExecutor
self
.
executor
=
core
.
ParallelExecutor
(
places
,
persistable_vars
,
main
.
desc
,
cpt
.
to_text
(
loss_name
)
if
loss_name
else
six
.
u
(
''
),
scope
,
local_scopes
,
exec_strategy
,
build_strategy
,
num_trainers
,
trainer_id
)
cpt
.
to_text
(
loss_name
)
if
loss_name
else
six
.
u
(
''
),
scope
,
local_scopes
,
exec_strategy
,
build_strategy
)
self
.
scope
=
scope
...
...
@@ -294,7 +293,7 @@ class ParallelExecutor(object):
res
.
append
(
res_dict
)
self
.
executor
.
feed_tensors_into_local_scopes
(
res
)
fetch_var_name
=
'
@FETCHED_VAR_NAME@
'
fetch_var_name
=
'
fetch
'
self
.
executor
.
run
(
fetch_list
,
fetch_var_name
)
arr
=
self
.
scope
.
find_var
(
fetch_var_name
).
get_lod_tensor_array
()
...
...
python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
浏览文件 @
7b73fc9e
...
...
@@ -19,6 +19,7 @@ import os
import
unittest
import
paddle.fluid
as
fluid
import
paddle.fluid.core
as
core
from
paddle.fluid
import
compiler
import
time
import
numpy
as
np
import
math
...
...
@@ -44,15 +45,8 @@ class TestParallelExecutorBase(unittest.TestCase):
optimizer
=
fluid
.
optimizer
.
Adam
,
use_fast_executor
=
False
,
enable_sequential_execution
=
False
):
def
run_executor
(
exe
,
feed
,
fetch_list
,
program
=
None
):
if
isinstance
(
exe
,
fluid
.
ParallelExecutor
):
res
=
exe
.
run
(
fetch_list
=
fetch_list
,
feed
=
feed
)
elif
isinstance
(
exe
,
fluid
.
Executor
):
if
program
is
None
:
program
=
fluid
.
default_main_program
()
res
=
exe
.
run
(
program
=
program
,
feed
=
feed
,
fetch_list
=
fetch_list
)
else
:
raise
ValueError
(
'Unkown type exe'
)
def
run_executor
(
exe
,
binary
,
feed
,
fetch_list
):
res
=
exe
.
run
(
binary
,
feed
=
feed
,
fetch_list
=
fetch_list
)
return
res
main
=
fluid
.
Program
()
...
...
@@ -72,8 +66,8 @@ class TestParallelExecutorBase(unittest.TestCase):
fluid
.
memory_optimize
(
main
)
place
=
fluid
.
CUDAPlace
(
0
)
if
use_cuda
else
fluid
.
CPUPlace
()
startup_
exe
=
fluid
.
Executor
(
place
)
startup_
exe
.
run
(
startup
)
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
startup
)
exec_strategy
=
fluid
.
ExecutionStrategy
()
exec_strategy
.
allow_op_delay
=
allow_op_delay
if
use_fast_executor
:
...
...
@@ -86,15 +80,13 @@ class TestParallelExecutorBase(unittest.TestCase):
build_strategy
.
enable_sequential_execution
=
enable_sequential_execution
if
use_cuda
and
core
.
is_compiled_with_cuda
():
build_strategy
.
remove_unnecessary_lock
=
True
if
use_parallel_executor
:
exe
=
fluid
.
ParallelExecutor
(
use_cuda
,
binary
=
compiler
.
CompiledProgram
(
main
).
with_data_parallel
(
loss_name
=
loss
.
name
,
exec_strategy
=
exec
_strategy
,
build_strategy
=
build
_strategy
)
build_strategy
=
build
_strategy
,
exec_strategy
=
exec
_strategy
)
else
:
exe
=
fluid
.
Executor
(
place
=
place
)
binary
=
compiler
.
CompiledProgram
(
main
)
if
batch_size
is
not
None
:
batch_size
*=
fluid
.
core
.
get_cuda_device_count
(
...
...
@@ -102,13 +94,14 @@ class TestParallelExecutorBase(unittest.TestCase):
os
.
environ
.
get
(
'CPU_NUM'
,
multiprocessing
.
cpu_count
()))
begin
=
time
.
time
()
first_loss
,
=
run_executor
(
exe
=
exe
,
feed
=
feed_dict
,
fetch_list
=
[
loss
.
name
])
exe
=
exe
,
binary
=
binary
,
feed
=
feed_dict
,
fetch_list
=
[
loss
.
name
])
for
i
in
range
(
iter
):
run_executor
(
exe
=
exe
,
feed
=
feed_dict
,
fetch_list
=
[])
run_executor
(
exe
=
exe
,
binary
=
binary
,
feed
=
feed_dict
,
fetch_list
=
[])
last_loss
,
=
run_executor
(
exe
=
exe
,
feed
=
feed_dict
,
fetch_list
=
[
loss
.
name
])
exe
=
exe
,
binary
=
binary
,
feed
=
feed_dict
,
fetch_list
=
[
loss
.
name
])
end
=
time
.
time
()
if
batch_size
is
not
None
:
...
...
python/paddle/fluid/tests/unittests/test_dist_base.py
浏览文件 @
7b73fc9e
...
...
@@ -26,6 +26,7 @@ import pickle
import
numpy
as
np
import
paddle.fluid
as
fluid
from
paddle.fluid
import
compiler
RUN_STEP
=
10
DEFAULT_BATCH_SIZE
=
2
...
...
@@ -104,8 +105,8 @@ class TestDistRunnerBase(object):
else
:
place
=
fluid
.
CPUPlace
()
startup_
exe
=
fluid
.
Executor
(
place
)
startup_
exe
.
run
(
fluid
.
default_startup_program
())
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
fluid
.
default_startup_program
())
strategy
=
fluid
.
ExecutionStrategy
()
strategy
.
num_threads
=
1
...
...
@@ -125,19 +126,16 @@ class TestDistRunnerBase(object):
mypass
.
set_int
(
"num_repeats"
,
args
.
batch_merge_repeat
)
if
args
.
update_method
==
"nccl2"
:
num_trainers
=
len
(
args
.
endpoints
.
split
(
","
))
trainer_id
=
args
.
trainer_id
build_stra
.
num_trainers
=
len
(
args
.
endpoints
.
split
(
","
))
build_stra
.
trainer_id
=
args
.
trainer_id
else
:
num_trainers
=
1
trainer_id
=
0
build_stra
.
num_trainers
=
1
build_stra
.
trainer_id
=
0
exe
=
fluid
.
ParallelExecutor
(
args
.
use_cuda
,
binary
=
compiler
.
CompiledProgram
(
trainer_prog
).
with_data_parallel
(
loss_name
=
avg_cost
.
name
,
exec_strategy
=
strategy
,
build_strategy
=
build_stra
,
num_trainers
=
num_trainers
,
trainer_id
=
trainer_id
)
exec_strategy
=
strategy
)
feed_var_list
=
[
var
for
var
in
trainer_prog
.
global_block
().
vars
.
values
()
...
...
@@ -160,7 +158,8 @@ class TestDistRunnerBase(object):
out_losses
=
[]
for
_
in
six
.
moves
.
xrange
(
RUN_STEP
):
loss
,
=
exe
.
run
(
fetch_list
=
[
avg_cost
.
name
],
loss
,
=
exe
.
run
(
binary
,
fetch_list
=
[
avg_cost
.
name
],
feed
=
feeder
.
feed
(
get_data
()))
out_losses
.
append
(
loss
[
0
])
if
six
.
PY2
:
...
...
python/paddle/fluid/tests/unittests/test_parallel_executor_test_while_train.py
浏览文件 @
7b73fc9e
...
...
@@ -15,6 +15,7 @@
from
__future__
import
print_function
import
paddle.fluid
as
fluid
from
paddle.fluid
import
compiler
import
paddle.fluid.core
as
core
import
numpy
as
np
import
unittest
...
...
@@ -61,22 +62,21 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase):
exe
.
run
(
startup
)
feed_dict
=
{
'image'
:
image
,
'label'
:
label
}
train_exe
=
fluid
.
ParallelExecutor
(
use_cuda
=
use_cuda
,
train_cp
=
compiler
.
CompiledProgram
(
main
).
with_data_parallel
(
loss_name
=
loss
.
name
,
build_strategy
=
build_strategy
)
test_cp
=
compiler
.
CompiledProgram
(
test_program
).
with_data_parallel
(
loss_name
=
loss
.
name
,
main_program
=
main
,
build_strategy
=
build_strategy
)
test_exe
=
fluid
.
ParallelExecutor
(
use_cuda
=
use_cuda
,
main_program
=
test_program
,
share_vars_from
=
train_exe
,
build_strategy
=
build_strategy
)
build_strategy
=
build_strategy
,
share_vars_from
=
train_cp
)
for
i
in
range
(
5
):
test_loss
,
=
test_exe
.
run
([
loss
.
name
],
feed
=
feed_dict
)
train_loss
,
=
train_exe
.
run
([
loss
.
name
],
feed
=
feed_dict
)
exe
.
run
(
train_cp
,
feed
=
feed_dict
,
fetch_list
=
[
loss
.
name
])
test_loss
,
=
exe
.
run
(
test_cp
,
feed
=
feed_dict
,
fetch_list
=
[
loss
.
name
])
train_loss
,
=
exe
.
run
(
train_cp
,
feed
=
feed_dict
,
fetch_list
=
[
loss
.
name
])
avg_test_loss_val
=
np
.
array
(
test_loss
).
mean
()
if
math
.
isnan
(
float
(
avg_test_loss_val
)):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录