提交 7de8d115 编写于 作者: F fengjiayi

follow comments

上级 7876b213
......@@ -762,10 +762,10 @@ def get_test_program(filelist, program=None, startup_program=None):
new_var.persistable = True
return new_var
def get_test_reader_name(train_reader_name):
def _get_test_reader_name(train_reader_name):
return train_reader_name + "_test"
def is_reader_op(op):
def _is_reader_op(op):
block = op.block
if "Out" in op.output_names:
reader_out = block.vars[op.output("Out")[0]]
......@@ -783,7 +783,7 @@ def get_test_program(filelist, program=None, startup_program=None):
startup_reader_op_list = []
for op in startup_block.ops:
if is_reader_op(op):
if _is_reader_op(op):
startup_reader_op_list.append(op)
if len(startup_reader_op_list) == 0:
......@@ -799,7 +799,7 @@ def get_test_program(filelist, program=None, startup_program=None):
test_reader = _copy_reader_var_(
startup_block,
train_reader,
new_name=get_test_reader_name(train_reader_name))
new_name=_get_test_reader_name(train_reader_name))
train_test_reader_map[train_reader.name] = test_reader
test_op_inputs = {}
......@@ -830,7 +830,7 @@ def get_test_program(filelist, program=None, startup_program=None):
for var in main_block.vars.values():
if var.type == core.VarDesc.VarType.READER:
main_block.rename_var(
str(var.name), str(get_test_reader_name(var.name)))
str(var.name), str(_get_test_reader_name(var.name)))
for op in main_block.ops:
if op.type == root_reader_op.type:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册