提交 a9417863 编写于 作者: Q qiaolongfei

replace concat_op with merge_ids_op

上级 509cb0bc
......@@ -618,7 +618,7 @@ class DistributeTranspiler:
if op.type == LOOKUP_TABLE_TYPE:
continue_search_lookup_table_op = True
op_index = list(all_ops).index(op)
lookup_table_op_index = list(all_ops).index(op)
ids_name = op.input("Ids")
out_name = op.output("Out")
......@@ -637,7 +637,7 @@ class DistributeTranspiler:
# insert split_ids_op
program.global_block().insert_op(
index=op_index,
index=lookup_table_op_index,
type="split_ids",
inputs={
'Ids': [
......@@ -649,7 +649,7 @@ class DistributeTranspiler:
# insert prefetch_op
program.global_block().insert_op(
index=op_index + 1,
index=lookup_table_op_index + 1,
type="prefetch",
inputs={'X': self.prefetch_input_vars},
outputs={"Out": self.prefetch_output_vars},
......@@ -660,16 +660,21 @@ class DistributeTranspiler:
# insert concat_op
program.global_block().insert_op(
index=op_index + 2,
type="concat",
inputs={'X': self.prefetch_output_vars},
index=lookup_table_op_index + 2,
type="merge_ids",
inputs={
'Ids': [
program.global_block().vars[varname]
for varname in ids_name
],
'X': self.prefetch_output_vars
},
outputs={
"Out": [
program.global_block().vars[varname]
for varname in out_name
]
},
attrs={"axis": 0})
})
# delete lookup_table_op
delete_ops(program.global_block(), [op])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册