提交 7d3d7226 编写于 作者: F fengjiayi

refine get_test_program

上级 a424f5a0
...@@ -259,10 +259,11 @@ def monkey_patch_reader_methods(reader): ...@@ -259,10 +259,11 @@ def monkey_patch_reader_methods(reader):
return reader return reader
def _copy_reader_var_(block, var, newname=None): def _copy_reader_var_(block, var, new_name=None):
if newname == None: if new_name == None:
newname = var.name new_name = var.name
new_var = block.create_var(name=var.name, type=core.VarDesc.VarType.READER) new_var = block.create_var(
name=str(new_name), type=core.VarDesc.VarType.READER)
new_var.desc.set_shapes(var.desc.shapes()) new_var.desc.set_shapes(var.desc.shapes())
new_var.desc.set_dtypes(var.desc.dtypes()) new_var.desc.set_dtypes(var.desc.dtypes())
new_var.persistable = True new_var.persistable = True
...@@ -693,62 +694,67 @@ def load(out, file_path, load_as_fp16=None): ...@@ -693,62 +694,67 @@ def load(out, file_path, load_as_fp16=None):
helper.append_op(type="load", inputs={}, output={"Out": out}, args=attrs) helper.append_op(type="load", inputs={}, output={"Out": out}, args=attrs)
def _is_reader_op(op, block):
if "Out" in op.output_names:
reader_out = block.vars[op.output("Out")[0]]
if reader_out.type == core.VarDesc.VarType.READER:
return True
return False
def get_test_program(filelist, program=None, startup_program=None): def get_test_program(filelist, program=None, startup_program=None):
""" """
Transpile current train program to a program to read test dataset Transpile current train program to a program to read test dataset
if the program is using reader ops like "open_files_op". if the program is using reader ops like "open_files_op".
""" """
def get_test_reader_name(train_reader_name):
return train_reader_name + "_test"
def is_reader_op(op):
block = op.block
if "Out" in op.output_names:
reader_out = block.vars[op.output("Out")[0]]
if reader_out.type == core.VarDesc.VarType.READER:
return True
return False
if program == None: if program == None:
program = default_main_program() program = default_main_program()
if startup_program == None: if startup_program == None:
startup_program = default_startup_program() startup_program = default_startup_program()
startup_block = startup_program.global_block()
# 1. find out the orignal reader var name # 1. find out the orignal reader var name
# open_files_var = None
# train_open_files_op = None
startup_reader_op_list = [] startup_reader_op_list = []
for op in startup_program.global_block().ops: for op in startup_block.ops:
if _is_reader_op(op, startup_program.global_block()): if is_reader_op(op):
startup_reader_op_list.append(op) startup_reader_op_list.append(op)
if len(startup_reader_op_list) == 0: if len(startup_reader_op_list) == 0:
return program return program
root_reader_op = startup_reader_op_list[0] root_reader_op = startup_reader_op_list[0]
train_test_reader_map = {}
# 2. add operators to startup to read open and read test data files # 2. add operators to startup to read open and read test data files
for op in startup_reader_op_list: for op in startup_reader_op_list:
orig_var_name = op.output("Out")[0] assert (len(op.output("Out")) == 1)
orig_var = startup_program.global_block().vars[orig_var_name] train_reader_name = op.output("Out")[0]
new_test_var = _copy_reader_var_( train_reader = startup_block.vars[train_reader_name]
startup_program.global_block(), test_reader = _copy_reader_var_(
orig_var, startup_block,
newname=orig_var_name + "_test") train_reader,
new_name=get_test_reader_name(train_reader_name))
# for open_files like operators have no input. train_test_reader_map[train_reader.name] = test_reader
inputs = None
if "UnderlyingReader" in op.input_names: test_op_inputs = {}
orig_input_var_name = op.input("UnderlyingReader")[0] for name in op.input_names:
orig_input_var = startup_program.global_block().vars[ train_arg_names = op.input(name)
orig_input_var_name] test_arg_vars = []
new_input_var = _copy_reader_var_( for arg_name in train_arg_names:
startup_program.global_block(), arg_var = train_test_reader_map[
orig_input_var, arg_name] if name == "UnderlyingReader" else startup_block.vars[
newname=orig_input_var_name + "_test") arg_name]
inputs = {"UnderlyingReader": new_input_var} test_arg_vars.append(arg_var)
test_op = startup_program.global_block().append_op( test_op_inputs[name] = test_arg_vars
test_op = startup_block.append_op(
type=op.type, type=op.type,
inputs=inputs, inputs=test_op_inputs,
outputs={'Out': [new_test_var]}, outputs={'Out': [test_reader]},
attrs=op.attrs) attrs=op.attrs)
# root reader op's filelist attr for read test files # root reader op's filelist attr for read test files
if op.type == root_reader_op.type: if op.type == root_reader_op.type:
...@@ -758,18 +764,19 @@ def get_test_program(filelist, program=None, startup_program=None): ...@@ -758,18 +764,19 @@ def get_test_program(filelist, program=None, startup_program=None):
# 3. rename reader vars in inference program to different name # 3. rename reader vars in inference program to different name
# to avoid read from train data. # to avoid read from train data.
origname = root_reader_op.output("Out")[0] main_block = program.global_block()
newname = origname + "_test" for var in main_block.vars.values():
program.global_block().rename_var(str(origname), str(newname)) if var.type == core.VarDesc.VarType.READER:
for op in program.global_block().ops: main_block.rename_var(
if _is_reader_op(op, program.global_block()): str(var.name), str(get_test_reader_name(var.name)))
origname = op.output("Out")[0]
newname = origname + "_test"
program.global_block().rename_var(str(origname), str(newname))
for op in main_block.ops:
if op.type == root_reader_op.type:
test_op.set_attr("file_names", filelist)
if op.type == "create_multi_pass_reader": if op.type == "create_multi_pass_reader":
op.set_attr("pass_num", 1) test_op.set_attr("pass_num", 1)
startup_program.sync_with_cpp()
program.sync_with_cpp() program.sync_with_cpp()
return program return program
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册