diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 27992df462ffd00ddf445538cc508b4232712481..ed4158bc4c91b4db0143c8d607a5ea9220e528ff 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -618,7 +618,7 @@ class DistributeTranspiler: if op.type == LOOKUP_TABLE_TYPE: continue_search_lookup_table_op = True - op_index = list(all_ops).index(op) + lookup_table_op_index = list(all_ops).index(op) ids_name = op.input("Ids") out_name = op.output("Out") @@ -637,7 +637,7 @@ class DistributeTranspiler: # insert split_ids_op program.global_block().insert_op( - index=op_index, + index=lookup_table_op_index, type="split_ids", inputs={ 'Ids': [ @@ -649,7 +649,7 @@ class DistributeTranspiler: # insert prefetch_op program.global_block().insert_op( - index=op_index + 1, + index=lookup_table_op_index + 1, type="prefetch", inputs={'X': self.prefetch_input_vars}, outputs={"Out": self.prefetch_output_vars}, @@ -660,16 +660,21 @@ class DistributeTranspiler: # insert concat_op program.global_block().insert_op( - index=op_index + 2, - type="concat", - inputs={'X': self.prefetch_output_vars}, + index=lookup_table_op_index + 2, + type="merge_ids", + inputs={ + 'Ids': [ + program.global_block().vars[varname] + for varname in ids_name + ], + 'X': self.prefetch_output_vars + }, outputs={ "Out": [ program.global_block().vars[varname] for varname in out_name ] - }, - attrs={"axis": 0}) + }) # delete lookup_table_op delete_ops(program.global_block(), [op])