diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 2192139f8d5950286691a77333dd8ec35505b033..fb2dd942fc379c564fbd96606eba03d904695109 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -474,6 +474,15 @@ class DistributeTranspiler(object): delete_ops(self.origin_program.global_block(), self.optimize_ops) delete_ops(self.origin_program.global_block(), lr_ops) + # delete table init op + if self.has_distributed_lookup_table: + trainer_table_param_init_op = [] + for op in self.startup_program.global_block().ops: + if self.table_name in op.output_arg_names: + trainer_table_param_init_op.append(op) + delete_ops(self.startup_program.global_block(), + trainer_table_param_init_op) + self.origin_program.__str__() if wait_port: @@ -1194,9 +1203,8 @@ to transpile() call.") # create table param and grad var in pserver program # create table optimize block in pserver program table_opt_op = [ - op for op in self.optimize_ops - if 'Param' in op.input_names and op.input("Param")[0] == - self.table_name + op for op in self.optimize_ops if 'Param' in op.input_names and + op.input("Param")[0] == self.table_name ][0] origin_param_var = self.origin_program.global_block().vars[