From a941786393e5b840a98215d81ba09c6cf0995ca1 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 11 Jun 2018 09:40:19 +0800 Subject: [PATCH] replace concat_op with merge_ids_op --- .../fluid/transpiler/distribute_transpiler.py | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 27992df462..ed4158bc4c 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -618,7 +618,7 @@ class DistributeTranspiler: if op.type == LOOKUP_TABLE_TYPE: continue_search_lookup_table_op = True - op_index = list(all_ops).index(op) + lookup_table_op_index = list(all_ops).index(op) ids_name = op.input("Ids") out_name = op.output("Out") @@ -637,7 +637,7 @@ class DistributeTranspiler: # insert split_ids_op program.global_block().insert_op( - index=op_index, + index=lookup_table_op_index, type="split_ids", inputs={ 'Ids': [ @@ -649,7 +649,7 @@ class DistributeTranspiler: # insert prefetch_op program.global_block().insert_op( - index=op_index + 1, + index=lookup_table_op_index + 1, type="prefetch", inputs={'X': self.prefetch_input_vars}, outputs={"Out": self.prefetch_output_vars}, @@ -660,16 +660,21 @@ class DistributeTranspiler: # insert concat_op program.global_block().insert_op( - index=op_index + 2, - type="concat", - inputs={'X': self.prefetch_output_vars}, + index=lookup_table_op_index + 2, + type="merge_ids", + inputs={ + 'Ids': [ + program.global_block().vars[varname] + for varname in ids_name + ], + 'X': self.prefetch_output_vars + }, outputs={ "Out": [ program.global_block().vars[varname] for varname in out_name ] - }, - attrs={"axis": 0}) + }) # delete lookup_table_op delete_ops(program.global_block(), [op]) -- GitLab