提交 a9417863 编写于 作者: Q qiaolongfei

replace concat_op with merge_ids_op

上级 509cb0bc
...@@ -618,7 +618,7 @@ class DistributeTranspiler: ...@@ -618,7 +618,7 @@ class DistributeTranspiler:
if op.type == LOOKUP_TABLE_TYPE: if op.type == LOOKUP_TABLE_TYPE:
continue_search_lookup_table_op = True continue_search_lookup_table_op = True
op_index = list(all_ops).index(op) lookup_table_op_index = list(all_ops).index(op)
ids_name = op.input("Ids") ids_name = op.input("Ids")
out_name = op.output("Out") out_name = op.output("Out")
...@@ -637,7 +637,7 @@ class DistributeTranspiler: ...@@ -637,7 +637,7 @@ class DistributeTranspiler:
# insert split_ids_op # insert split_ids_op
program.global_block().insert_op( program.global_block().insert_op(
index=op_index, index=lookup_table_op_index,
type="split_ids", type="split_ids",
inputs={ inputs={
'Ids': [ 'Ids': [
...@@ -649,7 +649,7 @@ class DistributeTranspiler: ...@@ -649,7 +649,7 @@ class DistributeTranspiler:
# insert prefetch_op # insert prefetch_op
program.global_block().insert_op( program.global_block().insert_op(
index=op_index + 1, index=lookup_table_op_index + 1,
type="prefetch", type="prefetch",
inputs={'X': self.prefetch_input_vars}, inputs={'X': self.prefetch_input_vars},
outputs={"Out": self.prefetch_output_vars}, outputs={"Out": self.prefetch_output_vars},
...@@ -660,16 +660,21 @@ class DistributeTranspiler: ...@@ -660,16 +660,21 @@ class DistributeTranspiler:
# insert concat_op # insert concat_op
program.global_block().insert_op( program.global_block().insert_op(
index=op_index + 2, index=lookup_table_op_index + 2,
type="concat", type="merge_ids",
inputs={'X': self.prefetch_output_vars}, inputs={
'Ids': [
program.global_block().vars[varname]
for varname in ids_name
],
'X': self.prefetch_output_vars
},
outputs={ outputs={
"Out": [ "Out": [
program.global_block().vars[varname] program.global_block().vars[varname]
for varname in out_name for varname in out_name
] ]
}, })
attrs={"axis": 0})
# delete lookup_table_op # delete lookup_table_op
delete_ops(program.global_block(), [op]) delete_ops(program.global_block(), [op])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册