diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index 7d10f12a54e2ec1b00ca20308358556bd43d5c71..15aa12d5d738350abbf9dcb84d4c9feeed4173e2 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -259,10 +259,11 @@ def monkey_patch_reader_methods(reader): return reader -def _copy_reader_var_(block, var, newname=None): - if newname == None: - newname = var.name - new_var = block.create_var(name=var.name, type=core.VarDesc.VarType.READER) +def _copy_reader_var_(block, var, new_name=None): + if new_name == None: + new_name = var.name + new_var = block.create_var( + name=str(new_name), type=core.VarDesc.VarType.READER) new_var.desc.set_shapes(var.desc.shapes()) new_var.desc.set_dtypes(var.desc.dtypes()) new_var.persistable = True @@ -693,62 +694,67 @@ def load(out, file_path, load_as_fp16=None): helper.append_op(type="load", inputs={}, output={"Out": out}, args=attrs) -def _is_reader_op(op, block): - if "Out" in op.output_names: - reader_out = block.vars[op.output("Out")[0]] - if reader_out.type == core.VarDesc.VarType.READER: - return True - return False - - def get_test_program(filelist, program=None, startup_program=None): """ Transpile current train program to a program to read test dataset if the program is using reader ops like "open_files_op". """ + + def get_test_reader_name(train_reader_name): + return train_reader_name + "_test" + + def is_reader_op(op): + block = op.block + if "Out" in op.output_names: + reader_out = block.vars[op.output("Out")[0]] + if reader_out.type == core.VarDesc.VarType.READER: + return True + return False + if program == None: program = default_main_program() if startup_program == None: startup_program = default_startup_program() + startup_block = startup_program.global_block() # 1. find out the orignal reader var name - # open_files_var = None - # train_open_files_op = None startup_reader_op_list = [] - for op in startup_program.global_block().ops: - if _is_reader_op(op, startup_program.global_block()): + for op in startup_block.ops: + if is_reader_op(op): startup_reader_op_list.append(op) if len(startup_reader_op_list) == 0: return program root_reader_op = startup_reader_op_list[0] - + train_test_reader_map = {} # 2. add operators to startup to read open and read test data files for op in startup_reader_op_list: - orig_var_name = op.output("Out")[0] - orig_var = startup_program.global_block().vars[orig_var_name] - new_test_var = _copy_reader_var_( - startup_program.global_block(), - orig_var, - newname=orig_var_name + "_test") - - # for open_files like operators have no input. - inputs = None - if "UnderlyingReader" in op.input_names: - orig_input_var_name = op.input("UnderlyingReader")[0] - orig_input_var = startup_program.global_block().vars[ - orig_input_var_name] - new_input_var = _copy_reader_var_( - startup_program.global_block(), - orig_input_var, - newname=orig_input_var_name + "_test") - inputs = {"UnderlyingReader": new_input_var} - test_op = startup_program.global_block().append_op( + assert (len(op.output("Out")) == 1) + train_reader_name = op.output("Out")[0] + train_reader = startup_block.vars[train_reader_name] + test_reader = _copy_reader_var_( + startup_block, + train_reader, + new_name=get_test_reader_name(train_reader_name)) + train_test_reader_map[train_reader.name] = test_reader + + test_op_inputs = {} + for name in op.input_names: + train_arg_names = op.input(name) + test_arg_vars = [] + for arg_name in train_arg_names: + arg_var = train_test_reader_map[ + arg_name] if name == "UnderlyingReader" else startup_block.vars[ + arg_name] + test_arg_vars.append(arg_var) + test_op_inputs[name] = test_arg_vars + + test_op = startup_block.append_op( type=op.type, - inputs=inputs, - outputs={'Out': [new_test_var]}, + inputs=test_op_inputs, + outputs={'Out': [test_reader]}, attrs=op.attrs) # root reader op's filelist attr for read test files if op.type == root_reader_op.type: @@ -758,18 +764,19 @@ def get_test_program(filelist, program=None, startup_program=None): # 3. rename reader vars in inference program to different name # to avoid read from train data. - origname = root_reader_op.output("Out")[0] - newname = origname + "_test" - program.global_block().rename_var(str(origname), str(newname)) - for op in program.global_block().ops: - if _is_reader_op(op, program.global_block()): - origname = op.output("Out")[0] - newname = origname + "_test" - program.global_block().rename_var(str(origname), str(newname)) + main_block = program.global_block() + for var in main_block.vars.values(): + if var.type == core.VarDesc.VarType.READER: + main_block.rename_var( + str(var.name), str(get_test_reader_name(var.name))) + for op in main_block.ops: + if op.type == root_reader_op.type: + test_op.set_attr("file_names", filelist) if op.type == "create_multi_pass_reader": - op.set_attr("pass_num", 1) + test_op.set_attr("pass_num", 1) + startup_program.sync_with_cpp() program.sync_with_cpp() return program