diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 629ded7f7a6e27c16093c479f523384e864a7e15..ac91c367962d03299f7315cf65f7187ea05775e5 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -574,11 +574,28 @@ def load_persist_vars_without_grad(executor, filename=None) -def load_lookup_table_vars(executor, dirname, pserver_id, table_name): +def load_lookup_table_vars(executor, dirname, program, pserver_id, table_name): + + for var in program.list_vars(): + if var.name == table_name: + lookup_table_var = var + break + + assert lookup_table_var is not None + lookup_table_dir = os.path.join(dirname, LOOKUP_TABLE_DIR) - table_file = table_name + CHECKPOINT_SEPARATOR + str(pserver_id) + table_file = table_name + CHECKPOINT_SEPARATOR + str(pserver_id) + + load_prog = Program() + load_block = load_prog.global_block() + + load_block.append_op( + type='load', + inputs={}, + outputs={'Out': [lookup_table_var]}, + attrs={'file_path': os.path.join(lookup_table_dir, table_file)}) - load_vars(executor, lookup_table_dir, vars=table_name, filename=table_file) + executor.run(load_prog) def save_persist_vars_without_grad(executor, dirname, program):