diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 29357f53c574b42de31f68fea7b5d624f502ad04..5826db292b2b9c36a8546041460c2eef13bd4821 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -477,12 +477,23 @@ class DistributeTranspiler(object): # delete table init op if self.has_distributed_lookup_table: - trainer_table_param_init_op = [] + table_var = self.startup_program.global_block().vars[ + self.table_name] + table_param_init_op = [] for op in self.startup_program.global_block().ops: if self.table_name in op.output_arg_names: - trainer_table_param_init_op.append(op) - delete_ops(self.startup_program.global_block(), - trainer_table_param_init_op) + table_param_init_op.append(op) + init_op_num = len(table_param_init_op) + if init_op_num != 1: + raise ValueError("table init op num should be 1, now is " + str( + init_op_num)) + table_init_op = table_param_init_op[1] + self.startup_program.global_block().append_op( + type="fake_init", + inputs={}, + outputs={"Out": table_var}, + attrs={"shape": table_init_op.attr('shape')}) + delete_ops(self.startup_program.global_block(), table_param_init_op) self.origin_program.__str__()