提交 d52fcaf4 编写于 作者: Q Qiao Longfei

replace table init op with fake init

上级 0328ffd3
......@@ -477,12 +477,23 @@ class DistributeTranspiler(object):
# delete table init op
if self.has_distributed_lookup_table:
trainer_table_param_init_op = []
table_var = self.startup_program.global_block().vars[
self.table_name]
table_param_init_op = []
for op in self.startup_program.global_block().ops:
if self.table_name in op.output_arg_names:
trainer_table_param_init_op.append(op)
delete_ops(self.startup_program.global_block(),
trainer_table_param_init_op)
table_param_init_op.append(op)
init_op_num = len(table_param_init_op)
if init_op_num != 1:
raise ValueError("table init op num should be 1, now is " + str(
init_op_num))
table_init_op = table_param_init_op[1]
self.startup_program.global_block().append_op(
type="fake_init",
inputs={},
outputs={"Out": table_var},
attrs={"shape": table_init_op.attr('shape')})
delete_ops(self.startup_program.global_block(), table_param_init_op)
self.origin_program.__str__()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册