提交 bccf8df5 编写于 作者: T tangwei12

bug fix

上级 5250ca8c
...@@ -574,11 +574,28 @@ def load_persist_vars_without_grad(executor, ...@@ -574,11 +574,28 @@ def load_persist_vars_without_grad(executor,
filename=None) filename=None)
def load_lookup_table_vars(executor, dirname, pserver_id, table_name): def load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
for var in program.list_vars():
if var.name == table_name:
lookup_table_var = var
break
assert lookup_table_var is not None
lookup_table_dir = os.path.join(dirname, LOOKUP_TABLE_DIR) lookup_table_dir = os.path.join(dirname, LOOKUP_TABLE_DIR)
table_file = table_name + CHECKPOINT_SEPARATOR + str(pserver_id) table_file = table_name + CHECKPOINT_SEPARATOR + str(pserver_id)
load_prog = Program()
load_block = load_prog.global_block()
load_block.append_op(
type='load',
inputs={},
outputs={'Out': [lookup_table_var]},
attrs={'file_path': os.path.join(lookup_table_dir, table_file)})
load_vars(executor, lookup_table_dir, vars=table_name, filename=table_file) executor.run(load_prog)
def save_persist_vars_without_grad(executor, dirname, program): def save_persist_vars_without_grad(executor, dirname, program):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册