Unverified commit c1ac5b63, authored by QI JUN, committed by GitHub

memory optimization for dynamic RNN (#8041)

* init

* add delete operator

* debug

* add wait

* clean code

* fix bug

* fix bug

* refine code

* remove unused code
Parent 292c1951
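The change has two halves. On the C++ side, WhileGradOp now waits on the device context and then deletes each step scope once the backward step finishes, so memory held by dynamic-RNN iterations is released eagerly. On the Python side, the memory-optimization transpiler gains a skip_opt set so that outputs of while/while_grad ops (and shapeless variables) are never chosen as buffer-reuse targets. The snippet below is a hedged usage sketch, not part of the diff: the module path, the memory_optimize entry point, and the layer calls are assumptions based on the fluid API of this era.

```python
# Hedged usage sketch: applying the memory-optimization transpiler touched by
# this commit to a program before execution. Names below (module path,
# memory_optimize, layer signatures) are assumptions; adapt to your checkout.
import paddle.v2.fluid as fluid
from paddle.v2.fluid.memory_optimization_transpiler import memory_optimize

x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.fc(input=x, size=1)
loss = fluid.layers.mean(x=y)
fluid.backward.append_backward(loss)

# Rewrites default_main_program in place so dead LoDTensOR buffers are reused.
# With a DynamicRNN in the network, the while/while_grad outputs would end up
# in the skip_opt set introduced by this commit and keep their own buffers.
memory_optimize(fluid.default_main_program())
```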
@@ -99,6 +99,9 @@ class WhileGradOp : public framework::OperatorBase {
void Run(const framework::Scope &scope,
const platform::Place &dev_place) const override {
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
framework::Executor executor(dev_place);
auto *block = Attr<framework::BlockDesc *>(kStepBlock);
auto *program = block->Program();
@@ -205,6 +208,8 @@ class WhileGradOp : public framework::OperatorBase {
sum_op->Run(cur_scope, dev_place);
cur_scope.Rename(new_inside_name, inside_grad_name);
}
dev_ctx.Wait();
const_cast<framework::Scope &>(scope).DeleteScope(&cur_scope);
}
}
};
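Why the added dev_ctx.Wait() matters: kernels launched for the gradient step may still be in flight on the device when the host deletes the step scope, so freeing the scope's tensors first would let those kernels touch released memory. The toy sketch below is not PaddlePaddle code; it mocks the device context and scope with plain Python threads only to show the required ordering.

```python
# Stand-alone illustration (assumed names, not PaddlePaddle code) of the
# ordering enforced above: asynchronous work queued against a step scope
# must finish before that scope's memory is released.
import threading
import time

class FakeDeviceContext(object):
    """Stands in for platform::DeviceContext; wait() drains queued work."""
    def __init__(self):
        self._pending = []

    def launch(self, kernel):
        t = threading.Thread(target=kernel)
        t.start()
        self._pending.append(t)

    def wait(self):
        for t in self._pending:
            t.join()
        self._pending = []

class FakeScope(object):
    """Stands in for framework::Scope; holds per-iteration buffers."""
    def __init__(self):
        self.buffers = {"x@GRAD": [0.0] * 4}

    def delete(self):
        self.buffers = None  # after this, a late kernel would touch freed memory

dev_ctx, cur_scope = FakeDeviceContext(), FakeScope()
dev_ctx.launch(lambda: (time.sleep(0.1),
                        cur_scope.buffers["x@GRAD"].__setitem__(0, 1.0)))
dev_ctx.wait()      # without this line, the delete below races with the kernel
cur_scope.delete()
```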
@@ -31,7 +31,7 @@ dtype_to_size = {
class ControlFlowGraph(object):
def __init__(self, Program, ops, forward_num):
def __init__(self, Program, ops, forward_num, skip_opt):
self._program = Program
self._ops = ops
self._forward_num = forward_num
@@ -41,6 +41,7 @@ class ControlFlowGraph(object):
self._defs = defaultdict(set)
self._live_in = defaultdict(set)
self._live_out = defaultdict(set)
self._skip_opt = skip_opt
def _add_connections(self, connections):
for node1, node2 in connections:
@@ -130,6 +131,10 @@ class ControlFlowGraph(object):
block_desc, x,
is_forward).type() != core.VarDesc.VarType.LOD_TENSOR:
return False
if x in self._skip_opt:
return False
if not self._find_var(block_desc, x, is_forward).shape():
return False
return True
self._build_graph()
@@ -140,6 +145,7 @@ class ControlFlowGraph(object):
if op.type() == "while" or op.type() == "while_grad":
continue
block_desc = op.block()
self.current_block_desc = block_desc
is_forward = i < self._forward_num
if self.pool:
defs_can_optimize = filter(
@@ -197,28 +203,32 @@ def get_cfgs(input_program):
block_desc = pdesc.block(0)
op_size = block_desc.op_size()
# Get global block ops
ops_list.append(([block_desc.op(i) for i in range(op_size)], op_size))
ops_list.append(
([block_desc.op(i) for i in range(op_size)], op_size, set()))
while_sub_block_ids = []
while_grad_sub_block_ids = []
while_pair = []
while_op_output = set()
while_block_id_pair = []
for i in range(op_size):
op = block_desc.op(i)
if op.type() == "while":
while_sub_block_ids.append(op.attr("sub_block").id)
while_op_output.update(op.output_arg_names())
elif op.type() == "while_grad":
while_grad_sub_block_ids.append(op.attr("sub_block").id)
while_op_output.update(op.output_arg_names())
# Find while/while_grad block pair
for grad_id in while_grad_sub_block_ids:
parent_id = pdesc.block(grad_id).parent
if parent_id in while_sub_block_ids:
while_pair.append((parent_id, grad_id))
while_block_id_pair.append((parent_id, grad_id))
while_sub_block_ids.remove(parent_id)
# Get while/while_grad block ops
for parent_id, grad_id in while_pair:
for parent_id, grad_id in while_block_id_pair:
while_block_ops = []
while_block = pdesc.block(parent_id)
while_block_op_size = while_block.op_size()
@@ -230,7 +240,7 @@ def get_cfgs(input_program):
for i in range(while_grad_block_op_size):
while_block_ops.append(while_grad_block.op(i))
ops_list.append((while_block_ops, while_block_op_size))
ops_list.append((while_block_ops, while_block_op_size, while_op_output))
# Process rest while block ops
for parent_id in while_sub_block_ids:
@@ -242,7 +252,7 @@ def get_cfgs(input_program):
ops_list.append((while_block_ops, while_block_op_size))
cfgs = [ControlFlowGraph(input_program, i, j) for i, j in ops_list]
cfgs = [ControlFlowGraph(input_program, i, j, k) for i, j, k in ops_list]
return cfgs
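The Python half changes the shape of the data handed to ControlFlowGraph: each ops_list entry is now an (ops, op_size, skip_opt) triple, where skip_opt collects the output argument names of the while/while_grad pair, and the reuse check rejects any variable in that set or with an empty shape. The sketch below uses assumed names and toy data, not the real module, to show that data flow end to end.

```python
# Condensed sketch (assumed names, toy data) of the data flow introduced here:
# ops_list entries carry a third element, the set of while/while_grad output
# names, which the control-flow graph treats as non-reusable.
class ControlFlowGraphSketch(object):
    def __init__(self, ops, forward_num, skip_opt):
        self._ops = ops
        self._forward_num = forward_num
        self._skip_opt = skip_opt          # variables that must keep their own buffer

    def can_be_reused(self, var_name, shape):
        if var_name in self._skip_opt:     # while/while_grad outputs: never reuse
            return False
        if not shape:                      # shapeless variables: never reuse
            return False
        return True

# Shape of ops_list after this change: (ops, op_size, skip_opt) triples.
ops_list = [
    (["fill_constant", "fc", "mean"], 3, set()),                 # global block
    (["while step ops ..."], 1, {"rnn_out", "rnn_out@GRAD"}),    # while pair
]
cfgs = [ControlFlowGraphSketch(ops, n, skip) for ops, n, skip in ops_list]
print(cfgs[1].can_be_reused("rnn_out", [10, 32]))    # False: excluded by skip_opt
print(cfgs[0].can_be_reused("fc_0.tmp_0", [10, 1]))  # True: eligible for reuse
```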