Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
4278518f
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2310
Star
20933
Fork
5423
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4278518f
编写于
8月 22, 2019
作者:
C
chengduo
提交者:
GitHub
8月 22, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Update CompiledProgram (#18919)
* use PE for compiler test=develop
上级
9240e532
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
129 addition
and
89 deletion
+129
-89
paddle/fluid/API.spec
paddle/fluid/API.spec
+3
-3
python/paddle/fluid/compiler.py
python/paddle/fluid/compiler.py
+81
-35
python/paddle/fluid/executor.py
python/paddle/fluid/executor.py
+8
-23
python/paddle/fluid/io.py
python/paddle/fluid/io.py
+16
-17
python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
...ddle/fluid/tests/unittests/parallel_executor_test_base.py
+1
-1
python/paddle/fluid/tests/unittests/test_eager_deletion_dynamic_rnn_base.py
...d/tests/unittests/test_eager_deletion_dynamic_rnn_base.py
+3
-2
python/paddle/fluid/tests/unittests/test_eager_deletion_recurrent_op.py
...fluid/tests/unittests/test_eager_deletion_recurrent_op.py
+2
-1
python/paddle/fluid/tests/unittests/test_eager_deletion_while_op.py
...dle/fluid/tests/unittests/test_eager_deletion_while_op.py
+3
-2
python/paddle/fluid/tests/unittests/test_inference_model_io.py
...n/paddle/fluid/tests/unittests/test_inference_model_io.py
+1
-2
python/paddle/fluid/tests/unittests/test_py_func_op.py
python/paddle/fluid/tests/unittests/test_py_func_op.py
+8
-1
python/paddle/fluid/tests/unittests/test_py_reader_using_executor.py
...le/fluid/tests/unittests/test_py_reader_using_executor.py
+3
-2
未找到文件。
paddle/fluid/API.spec
浏览文件 @
4278518f
...
...
@@ -47,9 +47,9 @@ paddle.fluid.DataFeedDesc.desc (ArgSpec(args=['self'], varargs=None, keywords=No
paddle.fluid.DataFeedDesc.set_batch_size (ArgSpec(args=['self', 'batch_size'], varargs=None, keywords=None, defaults=None), ('document', 'a34790bff4a2891713ddd644db56418d'))
paddle.fluid.DataFeedDesc.set_dense_slots (ArgSpec(args=['self', 'dense_slots_name'], varargs=None, keywords=None, defaults=None), ('document', 'fdd07ce63e72bed57f2c0db5bec5720f'))
paddle.fluid.DataFeedDesc.set_use_slots (ArgSpec(args=['self', 'use_slots_name'], varargs=None, keywords=None, defaults=None), ('document', 'c23a79dfa04edd014b477bd4b183da06'))
paddle.fluid.CompiledProgram ('paddle.fluid.compiler.CompiledProgram', ('document', '
6c45b5ccc24ae62d10115ce8abdc29a5
'))
paddle.fluid.CompiledProgram.__init__ (ArgSpec(args=['self', 'program_or_graph'
], varargs=None, keywords=None, defaults=None
), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.CompiledProgram.with_data_parallel (ArgSpec(args=['self', 'loss_name', 'build_strategy', 'exec_strategy', 'share_vars_from', 'places'], varargs=None, keywords=None, defaults=(None, None, None, None, None)), ('document', '
0e17773521634ef798fddd7d2ea3ef96
'))
paddle.fluid.CompiledProgram ('paddle.fluid.compiler.CompiledProgram', ('document', '
598d294107d44d7620bce76527a92c37
'))
paddle.fluid.CompiledProgram.__init__ (ArgSpec(args=['self', 'program_or_graph'
, 'build_strategy'], varargs=None, keywords=None, defaults=(None,)
), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.CompiledProgram.with_data_parallel (ArgSpec(args=['self', 'loss_name', 'build_strategy', 'exec_strategy', 'share_vars_from', 'places'], varargs=None, keywords=None, defaults=(None, None, None, None, None)), ('document', '
1c7c6171bbf6d77f2fce0166aa0ec43b
'))
paddle.fluid.CompiledProgram.with_inference_optimize (ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=None), ('document', '9e5b009d850191a010e859189c127fd8'))
paddle.fluid.ExecutionStrategy ('paddle.fluid.core_avx.ExecutionStrategy', ('document', '535ce28c4671176386e3cd283a764084'))
paddle.fluid.ExecutionStrategy.__init__ __init__(self: paddle.fluid.core_avx.ParallelExecutor.ExecutionStrategy) -> None
...
...
python/paddle/fluid/compiler.py
浏览文件 @
4278518f
...
...
@@ -45,6 +45,14 @@ def _is_pserver_mode(main_program):
return
False
def
_has_backward_op
(
graph
):
for
node
in
graph
.
nodes
():
if
node
.
is_op
()
and
node
.
op
()
is
not
None
and
\
node
.
op
().
type
().
endswith
(
"_grad"
):
return
True
return
False
def
_prune_feed_ops
(
program
):
# prune the feed ops in the program.
pop_idx
=
[]
...
...
@@ -101,9 +109,13 @@ class CompiledProgram(object):
(potentially optimized before), it will be directly used for
further optimizations. Note: graph is only supported when compiled
with with_data_parallel option.
build_strategy(BuildStrategy): build_strategy is used to
build the graph with the specified options.
For more information, please refer to fluid.BuildStrategy.
Default None.
"""
def
__init__
(
self
,
program_or_graph
):
def
__init__
(
self
,
program_or_graph
,
build_strategy
=
None
):
if
isinstance
(
program_or_graph
,
core
.
Graph
):
self
.
_graph
=
program_or_graph
# do not create a new program here.
...
...
@@ -122,6 +134,11 @@ class CompiledProgram(object):
self
.
_compiled
=
False
self
.
_is_data_parallel
=
False
self
.
_is_inference
=
False
self
.
_loss_name
=
None
self
.
_share_vars_from
=
None
self
.
_places
=
None
self
.
_build_strategy
=
build_strategy
self
.
_exec_strategy
=
None
def
with_data_parallel
(
self
,
loss_name
=
None
,
...
...
@@ -172,9 +189,11 @@ class CompiledProgram(object):
Args:
loss_name (str): The loss name must be set in training. Default None.
build_strategy(BuildStrategy): build_strategy is used to
build the graph so it can run on multiple devices/cores with
optimized topology.
build the graph with the specified options.
For more information, please refer to fluid.BuildStrategy.
Note that, if you set build_strategy in the argument list when
creating CompiledProgram and calling with_data_parallel,
the build_strategy in CompiledProgram will be overwritten by the latter.
Default None.
exec_strategy(ExecutionStrategy): exec_strategy is used to
select a way to execute the graph, for example how many
...
...
@@ -199,21 +218,23 @@ class CompiledProgram(object):
assert
not
self
.
_is_data_parallel
,
"Already compiled with parallel."
assert
not
self
.
_is_inference
,
"Cannot compile both data parallel and inference"
self
.
_is_data_parallel
=
True
self
.
_build_strategy
=
build_strategy
# FIXME(zcd): Currently, the build_strategy can be set during creating
# CompiledProgram or calling with_data_parallel, and it may be confusing,
# but in the long run, we should set up build_strategy only when creating
# CompiledProgram, and exec_strategy should be deprecated.
if
build_strategy
is
not
None
:
self
.
_build_strategy
=
build_strategy
self
.
_exec_strategy
=
exec_strategy
self
.
_loss_name
=
loss_name
self
.
_share_vars_from
=
share_vars_from
if
self
.
_exec_strategy
is
None
:
self
.
_exec_strategy
=
ExecutionStrategy
()
if
self
.
_build_strategy
is
None
:
self
.
_build_strategy
=
BuildStrategy
()
if
places
is
not
None
:
if
not
isinstance
(
places
,
(
list
,
tuple
)):
places
=
[
places
]
self
.
_places
=
places
else
:
self
.
_places
=
None
self
.
_build_strategy
.
is_distribution
=
_is_pserver_mode
(
self
.
_program
)
self
.
_places
=
places
if
_has_backward_op
(
self
.
_graph
):
assert
self
.
_loss_name
is
not
None
,
"The loss_name should be set here."
if
self
.
_places
is
not
None
:
if
not
isinstance
(
self
.
_places
,
(
list
,
tuple
)):
self
.
_places
=
[
self
.
_places
]
return
self
def
with_inference_optimize
(
self
,
config
):
...
...
@@ -238,10 +259,13 @@ class CompiledProgram(object):
def
_with_distributed
(
self
):
raise
NotImplementedError
()
def
_compile_data_parallel
(
self
,
use_cuda
=
False
,
scope
=
None
):
def
_compile_data_parallel
(
self
,
places
,
use_cuda
=
False
,
scope
=
None
):
if
self
.
_share_vars_from
:
if
scope
:
sys
.
stderr
.
write
(
"share_vars_from is set, scope is ignored.
\n
"
)
if
not
self
.
_is_data_parallel
:
raise
ValueError
(
"Currently, only data parallel mode need share_vars_from."
)
if
not
self
.
_share_vars_from
.
_is_data_parallel
:
raise
ValueError
(
"share_vars_from is not data parallel. Cannot "
"share vars from it."
)
...
...
@@ -254,24 +278,30 @@ class CompiledProgram(object):
assert
scope
is
not
None
,
""
self
.
_local_scopes
=
[]
assert
isinstance
(
places
,
tuple
)
or
isinstance
(
places
,
list
),
\
"Currently , The places type only should be list or tuple,
\n
"
\
"but the input type is {}."
.
format
(
type
(
places
))
if
self
.
_build_strategy
is
None
:
self
.
_build_strategy
=
BuildStrategy
()
self
.
_build_strategy
.
is_distribution
=
_is_pserver_mode
(
self
.
_program
)
if
self
.
_exec_strategy
is
None
:
self
.
_exec_strategy
=
ExecutionStrategy
()
self
.
_exec_strategy
.
use_cuda
=
use_cuda
has_set_place
=
(
self
.
_places
is
not
None
)
if
has_set_place
:
for
p
in
self
.
_places
:
assert
p
.
_type
()
==
self
.
_place
.
_type
(),
\
"Place type not match. You may set the wrong type of places"
else
:
self
.
_places
=
cuda_places
(
)
if
self
.
_exec_strategy
.
use_cuda
else
cpu_places
()
assert
self
.
_places
,
"no place for execution"
if
self
.
_exec_strategy
.
num_threads
==
0
:
if
self
.
_exec_strategy
.
use_cuda
:
# Experiments on se-resnext show that too many threads hurt
# performance. Worth tuning for other models in the future.
self
.
_exec_strategy
.
num_threads
=
len
(
self
.
_
places
)
*
4
self
.
_exec_strategy
.
num_threads
=
len
(
places
)
*
4
else
:
self
.
_exec_strategy
.
num_threads
=
len
(
self
.
_places
)
*
2
self
.
_exec_strategy
.
num_threads
=
len
(
places
)
*
2
if
self
.
_build_strategy
.
num_trainers
>
1
:
assert
self
.
_is_data_parallel
,
\
"If you use multi-trainer to train the model, you should use "
\
"the data parallel model, i.e. calling with_data_parallel function."
# TODO(wuyi): trainer endpoints should be passed in through
# build_strategy, not program.xxx.
...
...
@@ -298,7 +328,8 @@ class CompiledProgram(object):
node
.
var
().
type
()
!=
core
.
VarDesc
.
VarType
.
RAW
:
self
.
_persistable_vars
.
append
(
cpt
.
to_text
(
node
.
name
()))
places
=
list
(
map
(
_place_obj
,
self
.
_places
))
places
=
list
(
map
(
_place_obj
,
places
))
# ParallelExecutor would broadcast all the parameters during initializing.
# The parameters of each process should be in the same order for the data-parallelism
# distributed training to keep the broadcast correct.
...
...
@@ -335,13 +366,28 @@ class CompiledProgram(object):
self
.
_scope
=
scope
self
.
_place
=
place
if
self
.
_is_data_parallel
:
self
.
_executor
=
self
.
_compile_data_parallel
(
use_cuda
=
isinstance
(
self
.
_place
,
core
.
CUDAPlace
),
scope
=
self
.
_scope
)
elif
self
.
_is_inference
:
if
self
.
_is_inference
:
self
.
_executor
=
self
.
_compile_inference
()
else
:
p
=
_place_obj
(
self
.
_place
)
self
.
_executor
=
core
.
Executor
(
p
)
if
self
.
_is_data_parallel
:
self
.
_places
=
self
.
_get_places
(
self
.
_place
,
self
.
_places
)
else
:
self
.
_places
=
[
self
.
_place
]
self
.
_executor
=
self
.
_compile_data_parallel
(
use_cuda
=
isinstance
(
self
.
_place
,
core
.
CUDAPlace
),
scope
=
self
.
_scope
,
places
=
self
.
_places
)
return
self
def
_get_places
(
self
,
place
,
place_list
):
has_set_place
=
(
place_list
is
not
None
)
if
has_set_place
:
for
p
in
place_list
:
assert
p
.
_type
()
==
place
.
_type
(),
\
"Place type not match. You may set the wrong type of places"
else
:
place_list
=
cuda_places
()
if
isinstance
(
place
,
core
.
CUDAPlace
)
else
cpu_places
()
assert
place_list
,
"no place for execution"
return
place_list
python/paddle/fluid/executor.py
浏览文件 @
4278518f
...
...
@@ -643,7 +643,6 @@ class Executor(object):
if
not
compiled
:
return
self
.
_run_program
(
program
,
self
.
_default_executor
,
feed
=
feed
,
fetch_list
=
fetch_list
,
feed_var_name
=
feed_var_name
,
...
...
@@ -653,7 +652,9 @@ class Executor(object):
use_program_cache
=
use_program_cache
)
program
.
_compile
(
scope
,
self
.
place
)
if
program
.
_is_data_parallel
:
if
program
.
_is_inference
:
return
self
.
_run_inference
(
program
.
_executor
,
feed
)
else
:
return
self
.
_run_parallel
(
program
,
scope
=
scope
,
...
...
@@ -661,26 +662,8 @@ class Executor(object):
fetch_list
=
fetch_list
,
fetch_var_name
=
fetch_var_name
,
return_numpy
=
return_numpy
)
elif
program
.
_is_inference
:
return
self
.
_run_inference
(
program
.
_executor
,
feed
)
else
:
# TODO(panyx0718): Can compile program to optimize executor
# performance.
# TODO(panyx0718): executor should be able to run graph.
assert
program
.
_program
,
"CompiledProgram is compiled from graph, can only run with_data_parallel."
# use_program_cache is not valid with CompiledProgram
return
self
.
_run_program
(
program
.
_program
,
self
.
_default_executor
,
feed
=
feed
,
fetch_list
=
fetch_list
,
feed_var_name
=
feed_var_name
,
fetch_var_name
=
fetch_var_name
,
scope
=
scope
,
return_numpy
=
return_numpy
,
use_program_cache
=
False
)
def
_run_program
(
self
,
program
,
exe
,
feed
,
fetch_list
,
feed_var_name
,
def
_run_program
(
self
,
program
,
feed
,
fetch_list
,
feed_var_name
,
fetch_var_name
,
scope
,
return_numpy
,
use_program_cache
):
if
feed
is
None
:
...
...
@@ -742,9 +725,11 @@ class Executor(object):
self
.
_feed_data
(
program
,
feed
,
feed_var_name
,
scope
)
if
not
use_program_cache
:
exe
.
run
(
program
.
desc
,
scope
,
0
,
True
,
True
,
fetch_var_name
)
self
.
_default_executor
.
run
(
program
.
desc
,
scope
,
0
,
True
,
True
,
fetch_var_name
)
else
:
exe
.
run_cached_prepared_ctx
(
ctx
,
scope
,
False
,
False
,
False
)
self
.
_default_executor
.
run_cached_prepared_ctx
(
ctx
,
scope
,
False
,
False
,
False
)
arr
=
scope
.
find_var
(
fetch_var_name
).
get_lod_tensor_array
()
tensors
=
arr
.
_move_to_list
()
if
return_numpy
:
...
...
python/paddle/fluid/io.py
浏览文件 @
4278518f
...
...
@@ -111,6 +111,20 @@ def _clone_var_in_block_(block, var):
persistable
=
True
)
def _get_valid_program(main_program):
    """Normalize ``main_program`` to a ``Program`` instance.

    Args:
        main_program: ``None``, a ``Program`` or a ``CompiledProgram``.
            ``None`` is replaced by ``default_main_program()``. A
            ``CompiledProgram`` is replaced by its ``_program`` attribute
            and a warning is emitted.

    Returns:
        The resolved ``Program``.

    Raises:
        TypeError: If a ``CompiledProgram`` has ``_program`` set to
            ``None``, or if the resolved object is not a ``Program``.
    """
    type_error_text = "program should be as Program type or None"

    if main_program is None:
        main_program = default_main_program()
    elif isinstance(main_program, CompiledProgram):
        wrapped_program = main_program._program
        # Without an underlying program there is nothing to unwrap.
        if wrapped_program is None:
            raise TypeError(type_error_text)
        warnings.warn(
            "The input is a CompiledProgram, this is not recommended.")
        main_program = wrapped_program

    if isinstance(main_program, Program):
        return main_program
    raise TypeError(type_error_text)
def
save_vars
(
executor
,
dirname
,
main_program
=
None
,
...
...
@@ -193,13 +207,9 @@ def save_vars(executor,
# saved in the same file named 'var_file' in the path "./my_paddle_vars".
"""
save_dirname
=
os
.
path
.
normpath
(
dirname
)
main_program
=
_get_valid_program
(
main_program
)
if
vars
is
None
:
if
main_program
is
None
:
main_program
=
default_main_program
()
if
not
isinstance
(
main_program
,
Program
):
raise
TypeError
(
"program should be as Program type or None"
)
save_vars
(
executor
,
main_program
=
main_program
,
...
...
@@ -210,11 +220,6 @@ def save_vars(executor,
save_program
=
Program
()
save_block
=
save_program
.
global_block
()
if
main_program
is
None
:
main_program
=
default_main_program
()
if
not
isinstance
(
main_program
,
Program
):
raise
TypeError
(
"program should be as Program type or None"
)
save_var_map
=
{}
for
each_var
in
vars
:
# NOTE: don't save the variable which type is RAW
...
...
@@ -516,11 +521,9 @@ def save_persistables(executor, dirname, main_program=None, filename=None):
fluid.io.save_persistables(executor=exe, dirname=param_path,
main_program=prog)
"""
if
main_program
and
main_program
.
_is_distributed
:
_save_distributed_persistables
(
executor
,
dirname
=
dirname
,
main_program
=
main_program
)
else
:
save_vars
(
executor
,
...
...
@@ -1026,11 +1029,7 @@ def save_inference_model(dirname,
all
(
isinstance
(
var
,
Variable
)
for
var
in
target_vars
)):
raise
ValueError
(
"'target_vars' should be a list of Variable."
)
if
main_program
is
None
:
main_program
=
default_main_program
()
elif
not
isinstance
(
main_program
,
Program
):
raise
TypeError
(
"program should be as Program type or None"
)
main_program
=
_get_valid_program
(
main_program
)
# fix the bug that the activation op's output as target will be pruned.
# will affect the inference performance.
...
...
python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
浏览文件 @
4278518f
...
...
@@ -88,7 +88,7 @@ class TestParallelExecutorBase(unittest.TestCase):
build_strategy
=
build_strategy
,
exec_strategy
=
exec_strategy
)
else
:
binary
=
compiler
.
CompiledProgram
(
main
)
binary
=
main
if
batch_size
is
not
None
:
batch_size
*=
fluid
.
core
.
get_cuda_device_count
(
...
...
python/paddle/fluid/tests/unittests/test_eager_deletion_dynamic_rnn_base.py
浏览文件 @
4278518f
...
...
@@ -61,9 +61,10 @@ def train(network, use_cuda, use_parallel_executor, batch_size=32, pass_num=2):
fluid
.
default_main_program
().
random_seed
=
1
exe
.
run
(
fluid
.
default_startup_program
())
train_cp
=
compiler
.
CompiledProgram
(
fluid
.
default_main_program
()
)
train_cp
=
fluid
.
default_main_program
(
)
if
use_parallel_executor
:
train_cp
=
train_cp
.
with_data_parallel
(
loss_name
=
cost
.
name
)
train_cp
=
compiler
.
CompiledProgram
(
fluid
.
default_main_program
(
)).
with_data_parallel
(
loss_name
=
cost
.
name
)
fetch_list
=
[
cost
.
name
]
else
:
fetch_list
=
[
cost
]
...
...
python/paddle/fluid/tests/unittests/test_eager_deletion_recurrent_op.py
浏览文件 @
4278518f
...
...
@@ -192,13 +192,13 @@ class EagerDeletionRecurrentOpTest1(unittest.TestCase):
def
test_backward
(
self
,
rtol
=
0.01
):
self
.
check_forward
()
num_grad
=
self
.
get_numerical_gradient
()
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
append_backward
(
self
.
output
)
ana_grad
=
[
np
.
array
(
x
)
for
x
in
self
.
backward
()]
num_grad
=
self
.
get_numerical_gradient
()
for
idx
,
name
in
enumerate
(
self
.
data_field
):
self
.
assertEqual
(
num_grad
[
idx
].
shape
,
ana_grad
[
idx
].
shape
)
self
.
assertTrue
(
...
...
@@ -601,6 +601,7 @@ class EagerDeletionRecurrentOpParallelExecutorTest(
exec_strategy
=
fluid
.
ExecutionStrategy
()
parallel_exe
=
fluid
.
ParallelExecutor
(
use_cuda
=
False
,
loss_name
=
self
.
output
.
name
,
main_program
=
self
.
main_program
,
build_strategy
=
build_strategy
,
exec_strategy
=
exec_strategy
)
...
...
python/paddle/fluid/tests/unittests/test_eager_deletion_while_op.py
浏览文件 @
4278518f
...
...
@@ -128,9 +128,10 @@ class TestEagerDeletionWhileOpBase(unittest.TestCase):
exe
=
Executor
(
self
.
place
)
exe
.
run
(
fluid
.
default_startup_program
())
prog
=
compiler
.
CompiledProgram
(
fluid
.
default_main_program
()
)
prog
=
fluid
.
default_main_program
(
)
if
self
.
with_data_parallel
:
prog
=
prog
.
with_data_parallel
()
prog
=
compiler
.
CompiledProgram
(
fluid
.
default_main_program
(
)).
with_data_parallel
(
loss_name
=
loss
.
name
)
for
_
in
range
(
5
):
d
=
[]
...
...
python/paddle/fluid/tests/unittests/test_inference_model_io.py
浏览文件 @
4278518f
...
...
@@ -137,8 +137,7 @@ class TestInstance(unittest.TestCase):
cp_prog
=
CompiledProgram
(
program
).
with_data_parallel
(
loss_name
=
avg_cost
.
name
)
self
.
assertRaises
(
TypeError
,
save_inference_model
,
[
MODEL_DIR
,
[
"x"
,
"y"
],
[
avg_cost
],
exe
,
cp_prog
])
save_inference_model
(
MODEL_DIR
,
[
"x"
,
"y"
],
[
avg_cost
],
exe
,
cp_prog
)
self
.
assertRaises
(
TypeError
,
save_inference_model
,
[
MODEL_DIR
,
[
"x"
,
"y"
],
[
avg_cost
],
[],
cp_prog
])
...
...
python/paddle/fluid/tests/unittests/test_py_func_op.py
浏览文件 @
4278518f
...
...
@@ -142,8 +142,15 @@ def test_main(use_cuda, use_py_func_op, use_parallel_executor):
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
fluid
.
default_startup_program
())
train_cp
=
compiler
.
CompiledProgram
(
fluid
.
default_main_program
())
# FIXME: force the old memory optimize strategy here to pass the unittest,
# since enabling the new strategy will crash the unittest
fluid
.
memory_optimize
(
fluid
.
default_main_program
())
train_cp
=
fluid
.
default_main_program
()
if
use_parallel_executor
:
train_cp
=
compiler
.
CompiledProgram
(
fluid
.
default_main_program
(
))
train_cp
=
train_cp
.
with_data_parallel
(
loss_name
=
loss
.
name
)
fetch_list
=
[
loss
.
name
]
else
:
...
...
python/paddle/fluid/tests/unittests/test_py_reader_using_executor.py
浏览文件 @
4278518f
...
...
@@ -214,9 +214,10 @@ class TestPyReaderUsingExecutor(unittest.TestCase):
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
startup_program
)
train_cp
=
compiler
.
CompiledProgram
(
main_program
)
train_cp
=
main_program
if
use_parallel_executor
:
train_cp
=
train_cp
.
with_data_parallel
(
loss_name
=
loss
.
name
)
train_cp
=
compiler
.
CompiledProgram
(
main_program
).
with_data_parallel
(
loss_name
=
loss
.
name
)
if
use_cuda
:
self
.
batch_size_times
=
core
.
get_cuda_device_count
()
else
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录