提交 055662f9 编写于 作者: Q qiaolongfei

add insert_op for block

上级 b1a5a3ca
...@@ -659,7 +659,7 @@ class Block(object): ...@@ -659,7 +659,7 @@ class Block(object):
def __init__(self, program, idx): def __init__(self, program, idx):
self.desc = program.desc.block(idx) self.desc = program.desc.block(idx)
self.vars = dict() # var_name --> var self.vars = dict() # var_name --> var
self.ops = collections.deque() # operator list self.ops = list() # operator list
self.program = program self.program = program
self.removed_vars = dict() self.removed_vars = dict()
...@@ -831,6 +831,13 @@ class Block(object): ...@@ -831,6 +831,13 @@ class Block(object):
self.ops.append(op) self.ops.append(op)
return op return op
def insert_op(self, index, *args, **kwargs):
self.sync_with_cpp()
op_desc = self.desc.insert_op(index)
op = Operator(block=self, desc=op_desc, *args, **kwargs)
self.ops.insert(index, op)
return op
def delete_ops(self, ops): def delete_ops(self, ops):
# remove from cpp # remove from cpp
# FIXME(typhoonzero): remove only the first occurrence. # FIXME(typhoonzero): remove only the first occurrence.
...@@ -842,12 +849,12 @@ class Block(object): ...@@ -842,12 +849,12 @@ class Block(object):
self.desc.remove_op(start, end + 1) self.desc.remove_op(start, end + 1)
def slice_ops(self, start, end): def slice_ops(self, start, end):
return list(self.ops)[start:end] return self.ops[start:end]
def prepend_op(self, *args, **kwargs): def prepend_op(self, *args, **kwargs):
op_desc = self.desc.prepend_op() op_desc = self.desc.prepend_op()
op = Operator(self, op_desc, *args, **kwargs) op = Operator(self, op_desc, *args, **kwargs)
self.ops.appendleft(op) self.ops.insert(0, op)
return op return op
def sync_with_cpp(self): def sync_with_cpp(self):
...@@ -892,7 +899,7 @@ class Block(object): ...@@ -892,7 +899,7 @@ class Block(object):
for index in range((start_index - 1 - 1), -1, -1): for index in range((start_index - 1 - 1), -1, -1):
op_desc = ops_in_cpp[index] op_desc = ops_in_cpp[index]
op = Operator(self, op_desc) op = Operator(self, op_desc)
self.ops.appendleft(op) self.ops.insert(0, op)
# sync ops append to the end of cpp_ops # sync ops append to the end of cpp_ops
for index in range((end_index + 1), len(ops_in_cpp)): for index in range((end_index + 1), len(ops_in_cpp)):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册