提交 a424f5a0 编写于 作者: Y yi.wu

polish reader op for test

上级 343c1957
...@@ -259,7 +259,9 @@ def monkey_patch_reader_methods(reader): ...@@ -259,7 +259,9 @@ def monkey_patch_reader_methods(reader):
return reader return reader
def _copy_reader_var_(block, var): def _copy_reader_var_(block, var, newname=None):
if newname == None:
newname = var.name
new_var = block.create_var(name=var.name, type=core.VarDesc.VarType.READER) new_var = block.create_var(name=var.name, type=core.VarDesc.VarType.READER)
new_var.desc.set_shapes(var.desc.shapes()) new_var.desc.set_shapes(var.desc.shapes())
new_var.desc.set_dtypes(var.desc.dtypes()) new_var.desc.set_dtypes(var.desc.dtypes())
...@@ -691,68 +693,80 @@ def load(out, file_path, load_as_fp16=None): ...@@ -691,68 +693,80 @@ def load(out, file_path, load_as_fp16=None):
helper.append_op(type="load", inputs={}, output={"Out": out}, args=attrs) helper.append_op(type="load", inputs={}, output={"Out": out}, args=attrs)
def get_test_program(filelist, test_program=None, startup_program=None): def _is_reader_op(op, block):
""" if "Out" in op.output_names:
Transpile current program to read test dataset if the program reader_out = block.vars[op.output("Out")[0]]
is using reader ops like "open_files_op". if reader_out.type == core.VarDesc.VarType.READER:
return True
return False
Args:
filelist (list): list of test file paths.
test_program (Program|None): program to run test/evaluation.
default use fluid.default_main_program()
startup_program (Program|None): startup program to change,
default use fluid.default_startup_program()
Returns: def get_test_program(filelist, program=None, startup_program=None):
Program: program for test
""" """
if test_program == None: Transpile current train program to a program to read test dataset
if the program is using reader ops like "open_files_op".
"""
if program == None:
program = default_main_program() program = default_main_program()
if startup_program == None: if startup_program == None:
startup_program = default_startup_program() startup_program = default_startup_program()
# 1. find out the orignal reader var name # 1. find out the orignal reader var name
open_files_var = None # open_files_var = None
train_open_files_op = None # train_open_files_op = None
startup_reader_op_list = []
for op in startup_program.global_block().ops: for op in startup_program.global_block().ops:
if op.type == "open_files": if _is_reader_op(op, startup_program.global_block()):
train_open_files_op = op startup_reader_op_list.append(op)
open_files_var_name = op.output("Out")[0]
open_files_var = startup_program.global_block().vars[ if len(startup_reader_op_list) == 0:
open_files_var_name] return program
# 2. add operator to startup to read open and read test data files root_reader_op = startup_reader_op_list[0]
test_startup_var = startup_program.global_block().create_var(
name=open_files_var.name + "_test") # 2. add operators to startup to read open and read test data files
for op in startup_reader_op_list:
print("creating openfiles for test reader: ", train_open_files_op.attrs) orig_var_name = op.output("Out")[0]
startup_program.global_block().append_op( orig_var = startup_program.global_block().vars[orig_var_name]
type='open_files', new_test_var = _copy_reader_var_(
outputs={'Out': [test_startup_var]}, startup_program.global_block(),
attrs={ orig_var,
'shape_concat': train_open_files_op.attrs["shape_concat"], newname=orig_var_name + "_test")
'lod_levels': train_open_files_op.attrs["lod_levels"],
'ranks': train_open_files_op.attrs["ranks"], # for open_files like operators have no input.
'file_names': filelist, inputs = None
'thread_num': train_open_files_op.attrs["thread_num"], if "UnderlyingReader" in op.input_names:
'buffer_size': train_open_files_op.attrs["buffer_size"] orig_input_var_name = op.input("UnderlyingReader")[0]
}) orig_input_var = startup_program.global_block().vars[
dtypes = [convert_np_dtype_to_dtype_(dt) for dt in ["float32", "int64"]] orig_input_var_name]
test_startup_var.desc.set_dtypes(dtypes) new_input_var = _copy_reader_var_(
test_startup_var.persistable = True startup_program.global_block(),
_copy_reader_var_(default_main_program().global_block(), test_startup_var) orig_input_var,
newname=orig_input_var_name + "_test")
inputs = {"UnderlyingReader": new_input_var}
test_op = startup_program.global_block().append_op(
type=op.type,
inputs=inputs,
outputs={'Out': [new_test_var]},
attrs=op.attrs)
# root reader op's filelist attr for read test files
if op.type == root_reader_op.type:
test_op.set_attr("file_names", filelist)
if op.type == "create_multi_pass_reader":
test_op.set_attr("pass_num", 1)
# 3. rename reader vars in inference program to different name # 3. rename reader vars in inference program to different name
# to avoid read from train data. # to avoid read from train data.
program.global_block().rename_var(open_files_var.name, origname = root_reader_op.output("Out")[0]
test_startup_var.name) newname = origname + "_test"
program.global_block().rename_var(str(origname), str(newname))
for op in program.global_block().ops: for op in program.global_block().ops:
if "Out" in op.output_names: if _is_reader_op(op, program.global_block()):
op_out_var_name = op.output("Out")[0] origname = op.output("Out")[0]
op_out_var = program.global_block().vars[op_out_var_name] newname = origname + "_test"
if op_out_var.type == core.VarDesc.VarType.READER: program.global_block().rename_var(str(origname), str(newname))
newname = op_out_var.name + "_test"
program.global_block().rename_var(op_out_var.name, newname)
if op.type == "create_multi_pass_reader": if op.type == "create_multi_pass_reader":
op.set_attr("pass_num", 1) op.set_attr("pass_num", 1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册