def get_test_program(filelist, program=None, startup_program=None):
    """
    Transpile current train program to a program to read test dataset
    if the program is using reader ops like "open_files_op".

    Args:
        filelist: value written to the "file_names" attribute of every
            reader op whose type matches the first reader op found in the
            startup program, so those ops read the test files.
        program: main program to convert. ``default_main_program()`` is
            used when it is None.
        startup_program: startup program that holds the reader ops.
            ``default_startup_program()`` is used when it is None.

    Returns:
        ``program`` itself. When the startup program contains no reader
        op it is returned untouched; otherwise both ``program`` and
        ``startup_program`` are modified in place.
    """

    def _copy_reader_var_(block, var, new_name=None):
        # Create a persistable READER variable in `block` that carries the
        # same shapes and dtypes as `var`.
        if new_name is None:
            new_name = var.name
        new_var = block.create_var(
            name=str(new_name), type=core.VarDesc.VarType.READER)
        new_var.desc.set_shapes(var.desc.shapes())
        new_var.desc.set_dtypes(var.desc.dtypes())
        new_var.persistable = True
        return new_var

    def get_test_reader_name(train_reader_name):
        return train_reader_name + "_test"

    def is_reader_op(op):
        # An op counts as a reader op when its first "Out" variable has
        # the READER var type.
        if "Out" not in op.output_names:
            return False
        reader_out = op.block.vars[op.output("Out")[0]]
        return reader_out.type == core.VarDesc.VarType.READER

    if program is None:
        program = default_main_program()
    if startup_program is None:
        startup_program = default_startup_program()
    startup_block = startup_program.global_block()

    # 1. find out the original reader var name
    startup_reader_op_list = [
        op for op in startup_block.ops if is_reader_op(op)
    ]

    if not startup_reader_op_list:
        return program

    root_reader_op = startup_reader_op_list[0]
    train_test_reader_map = {}
    # 2. add operators to startup to open and read test data files
    for op in startup_reader_op_list:
        assert (len(op.output("Out")) == 1)
        train_reader_name = op.output("Out")[0]
        train_reader = startup_block.vars[train_reader_name]
        test_reader = _copy_reader_var_(
            startup_block,
            train_reader,
            new_name=get_test_reader_name(train_reader_name))
        train_test_reader_map[train_reader.name] = test_reader

        test_op_inputs = {}
        for name in op.input_names:
            test_arg_vars = []
            for arg_name in op.input(name):
                if name == "UnderlyingReader":
                    # Chain decorated readers onto the test copy of the
                    # reader they wrap, not onto the train reader.
                    arg_var = train_test_reader_map[arg_name]
                else:
                    arg_var = startup_block.vars[arg_name]
                test_arg_vars.append(arg_var)
            test_op_inputs[name] = test_arg_vars

        test_op = startup_block.append_op(
            type=op.type,
            inputs=test_op_inputs,
            outputs={'Out': [test_reader]},
            attrs=op.attrs)
        # root reader op's filelist attr for read test files
        if op.type == root_reader_op.type:
            test_op.set_attr("file_names", filelist)
        if op.type == "create_multi_pass_reader":
            test_op.set_attr("pass_num", 1)

    # 3. rename reader vars in inference program to different name
    #    to avoid read from train data.
    main_block = program.global_block()
    # Iterate over a snapshot of the variables.
    # NOTE(review): rename_var presumably re-keys main_block.vars, which
    # would break iteration over a live dict view on Python 3 - confirm.
    for var in list(main_block.vars.values()):
        if var.type == core.VarDesc.VarType.READER:
            main_block.rename_var(
                str(var.name), str(get_test_reader_name(var.name)))

    for op in main_block.ops:
        # Update the op being visited. The previous code called set_attr on
        # `test_op`, the variable left over from the startup loop above, so
        # main-program ops were never updated.
        if op.type == root_reader_op.type:
            op.set_attr("file_names", filelist)
        if op.type == "create_multi_pass_reader":
            op.set_attr("pass_num", 1)

    startup_program.sync_with_cpp()
    program.sync_with_cpp()

    return program