提交 7de8d115 编写于 作者: F fengjiayi

follow comments

上级 7876b213
...@@ -762,10 +762,10 @@ def get_test_program(filelist, program=None, startup_program=None): ...@@ -762,10 +762,10 @@ def get_test_program(filelist, program=None, startup_program=None):
new_var.persistable = True new_var.persistable = True
return new_var return new_var
def get_test_reader_name(train_reader_name): def _get_test_reader_name(train_reader_name):
return train_reader_name + "_test" return train_reader_name + "_test"
def is_reader_op(op): def _is_reader_op(op):
block = op.block block = op.block
if "Out" in op.output_names: if "Out" in op.output_names:
reader_out = block.vars[op.output("Out")[0]] reader_out = block.vars[op.output("Out")[0]]
...@@ -783,7 +783,7 @@ def get_test_program(filelist, program=None, startup_program=None): ...@@ -783,7 +783,7 @@ def get_test_program(filelist, program=None, startup_program=None):
startup_reader_op_list = [] startup_reader_op_list = []
for op in startup_block.ops: for op in startup_block.ops:
if is_reader_op(op): if _is_reader_op(op):
startup_reader_op_list.append(op) startup_reader_op_list.append(op)
if len(startup_reader_op_list) == 0: if len(startup_reader_op_list) == 0:
...@@ -799,7 +799,7 @@ def get_test_program(filelist, program=None, startup_program=None): ...@@ -799,7 +799,7 @@ def get_test_program(filelist, program=None, startup_program=None):
test_reader = _copy_reader_var_( test_reader = _copy_reader_var_(
startup_block, startup_block,
train_reader, train_reader,
new_name=get_test_reader_name(train_reader_name)) new_name=_get_test_reader_name(train_reader_name))
train_test_reader_map[train_reader.name] = test_reader train_test_reader_map[train_reader.name] = test_reader
test_op_inputs = {} test_op_inputs = {}
...@@ -830,7 +830,7 @@ def get_test_program(filelist, program=None, startup_program=None): ...@@ -830,7 +830,7 @@ def get_test_program(filelist, program=None, startup_program=None):
for var in main_block.vars.values(): for var in main_block.vars.values():
if var.type == core.VarDesc.VarType.READER: if var.type == core.VarDesc.VarType.READER:
main_block.rename_var( main_block.rename_var(
str(var.name), str(get_test_reader_name(var.name))) str(var.name), str(_get_test_reader_name(var.name)))
for op in main_block.ops: for op in main_block.ops:
if op.type == root_reader_op.type: if op.type == root_reader_op.type:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册