提交 0c7d6eb8 编写于 作者: M minqiyang

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into port_python3_syntax

...@@ -88,9 +88,8 @@ class BlockDesc { ...@@ -88,9 +88,8 @@ class BlockDesc {
OpDesc *InsertOp(size_t index); OpDesc *InsertOp(size_t index);
/* /*
* Remove Op and its input/output variables. * Only remove op itself,
* Note that for either input or output variable, if it is also an input or * do nothing to its input and output variables
* output variable of other ops, we should remain it.
*/ */
void RemoveOp(size_t s, size_t e); void RemoveOp(size_t s, size_t e);
......
...@@ -59,13 +59,12 @@ inline const char* cudnnGetErrorString(cudnnStatus_t status) { ...@@ -59,13 +59,12 @@ inline const char* cudnnGetErrorString(cudnnStatus_t status) {
#define CUDNN_VERSION_MIN(major, minor, patch) \ #define CUDNN_VERSION_MIN(major, minor, patch) \
(CUDNN_VERSION >= ((major)*1000 + (minor)*100 + (patch))) (CUDNN_VERSION >= ((major)*1000 + (minor)*100 + (patch)))
#define CUDNN_ENFORCE(condition) \ #define CUDNN_ENFORCE(condition) \
do { \ do { \
cudnnStatus_t status = condition; \ cudnnStatus_t status = condition; \
if (status != CUDNN_STATUS_SUCCESS) { \ if (UNLIKELY(status != CUDNN_STATUS_SUCCESS)) { \
VLOG(1) << ::paddle::platform::cudnnGetErrorString(status); \ PADDLE_THROW(::paddle::platform::cudnnGetErrorString(status)); \
PADDLE_THROW("cuDNN call failed"); \ } \
} \
} while (false) } while (false)
enum class DataLayout { // Not use enum class DataLayout { // Not use
......
...@@ -1555,7 +1555,12 @@ class Program(object): ...@@ -1555,7 +1555,12 @@ class Program(object):
def inference_optimize(self): def inference_optimize(self):
""" """
This method will create a new program and change the :code:`is_test` This method will create a new program and do following adjustments on it:
1. Remove all reader variables and their creator ops if exist.
2. Remove the :code:`read_op` if exists.
3. change the :code:`is_test`
attribute of operators to :code:`True`. All the :code:`Parameter` attribute of operators to :code:`True`. All the :code:`Parameter`
information will be lost. information will be lost.
...@@ -1569,6 +1574,22 @@ class Program(object): ...@@ -1569,6 +1574,22 @@ class Program(object):
# core.inference_optimize being fixed. # core.inference_optimize being fixed.
res = Program() res = Program()
res.desc = core.ProgramDesc(self.desc) res.desc = core.ProgramDesc(self.desc)
# remove all readers and the read_op if exist
read_op_idx = 0
root_block = res.desc.block(0)
while True:
if read_op_idx >= root_block.op_size() or root_block.op(
read_op_idx).type() == 'read':
break
read_op_idx += 1
if read_op_idx < root_block.op_size():
root_block._remove_op(0, read_op_idx + 1)
for var in root_block.all_vars():
if var.type() == core.VarDesc.VarType.READER:
root_block._remove_var(var.name())
# change all `is_test` attributes to True
for i in range(res.desc.num_blocks()): for i in range(res.desc.num_blocks()):
block = res.desc.block(i) block = res.desc.block(i)
for j in range(block.op_size()): for j in range(block.op_size()):
......
...@@ -443,9 +443,6 @@ def random_data_generator(low, high, shapes, lod_levels, for_parallel=True): ...@@ -443,9 +443,6 @@ def random_data_generator(low, high, shapes, lod_levels, for_parallel=True):
main_prog_var = _copy_reader_var_(default_main_program().current_block(), main_prog_var = _copy_reader_var_(default_main_program().current_block(),
startup_var) startup_var)
if for_parallel:
main_prog_var = parallel(reader=main_prog_var)
return monkey_patch_reader_methods(main_prog_var) return monkey_patch_reader_methods(main_prog_var)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册