提交 604bd85a 编写于 作者: F fengjiayi

update inference_optimize()

上级 d7e08c53
...@@ -88,9 +88,8 @@ class BlockDesc { ...@@ -88,9 +88,8 @@ class BlockDesc {
OpDesc *InsertOp(size_t index); OpDesc *InsertOp(size_t index);
/* /*
* Remove Op and its input/output variables. * Only remove op itself,
* Note that for either input or output variable, if it is also an input or * do nothing to its input and output variables
* output variable of other ops, we should remain it.
*/ */
void RemoveOp(size_t s, size_t e); void RemoveOp(size_t s, size_t e);
......
...@@ -1540,7 +1540,12 @@ class Program(object): ...@@ -1540,7 +1540,12 @@ class Program(object):
def inference_optimize(self): def inference_optimize(self):
""" """
This method will create a new program and change the :code:`is_test` This method will create a new program and do following adjustments on it:
1. Remove all reader variables and their creator ops if exist.
2. Remove the :code:`read_op` if exists.
3. change the :code:`is_test`
attribute of operators to :code:`True`. All the :code:`Parameter` attribute of operators to :code:`True`. All the :code:`Parameter`
information will be lost. information will be lost.
...@@ -1554,6 +1559,22 @@ class Program(object): ...@@ -1554,6 +1559,22 @@ class Program(object):
# core.inference_optimize being fixed. # core.inference_optimize being fixed.
res = Program() res = Program()
res.desc = core.ProgramDesc(self.desc) res.desc = core.ProgramDesc(self.desc)
# remove all readers and the read_op if exist
read_op_idx = 0
root_block = res.desc.block(0)
while True:
if read_op_idx >= root_block.op_size() or root_block.op(
read_op_idx).type() == 'read':
break
read_op_idx += 1
if read_op_idx < root_block.op_size():
root_block._remove_op(0, read_op_idx + 1)
for var in root_block.all_vars():
if var.type() == core.VarDesc.VarType.READER:
root_block._remove_var(var.name())
# change all `is_test` attributes to True
for i in xrange(res.desc.num_blocks()): for i in xrange(res.desc.num_blocks()):
block = res.desc.block(i) block = res.desc.block(i)
for j in xrange(block.op_size()): for j in xrange(block.op_size()):
......
...@@ -443,9 +443,6 @@ def random_data_generator(low, high, shapes, lod_levels, for_parallel=True): ...@@ -443,9 +443,6 @@ def random_data_generator(low, high, shapes, lod_levels, for_parallel=True):
main_prog_var = _copy_reader_var_(default_main_program().current_block(), main_prog_var = _copy_reader_var_(default_main_program().current_block(),
startup_var) startup_var)
if for_parallel:
main_prog_var = parallel(reader=main_prog_var)
return monkey_patch_reader_methods(main_prog_var) return monkey_patch_reader_methods(main_prog_var)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册