未验证 提交 c1ac5b63 编写于 作者: Q QI JUN 提交者: GitHub

memory optimization for dynamic RNN (#8041)

* init

* add delete operator

* debug

* add wait

* clean code

* fix bug

* fix bug

* refine code

* remove unused code
上级 292c1951
...@@ -99,6 +99,9 @@ class WhileGradOp : public framework::OperatorBase { ...@@ -99,6 +99,9 @@ class WhileGradOp : public framework::OperatorBase {
void Run(const framework::Scope &scope, void Run(const framework::Scope &scope,
const platform::Place &dev_place) const override { const platform::Place &dev_place) const override {
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
framework::Executor executor(dev_place); framework::Executor executor(dev_place);
auto *block = Attr<framework::BlockDesc *>(kStepBlock); auto *block = Attr<framework::BlockDesc *>(kStepBlock);
auto *program = block->Program(); auto *program = block->Program();
...@@ -205,6 +208,8 @@ class WhileGradOp : public framework::OperatorBase { ...@@ -205,6 +208,8 @@ class WhileGradOp : public framework::OperatorBase {
sum_op->Run(cur_scope, dev_place); sum_op->Run(cur_scope, dev_place);
cur_scope.Rename(new_inside_name, inside_grad_name); cur_scope.Rename(new_inside_name, inside_grad_name);
} }
dev_ctx.Wait();
const_cast<framework::Scope &>(scope).DeleteScope(&cur_scope);
} }
} }
}; };
......
...@@ -31,7 +31,7 @@ dtype_to_size = { ...@@ -31,7 +31,7 @@ dtype_to_size = {
class ControlFlowGraph(object): class ControlFlowGraph(object):
def __init__(self, Program, ops, forward_num): def __init__(self, Program, ops, forward_num, skip_opt):
self._program = Program self._program = Program
self._ops = ops self._ops = ops
self._forward_num = forward_num self._forward_num = forward_num
...@@ -41,6 +41,7 @@ class ControlFlowGraph(object): ...@@ -41,6 +41,7 @@ class ControlFlowGraph(object):
self._defs = defaultdict(set) self._defs = defaultdict(set)
self._live_in = defaultdict(set) self._live_in = defaultdict(set)
self._live_out = defaultdict(set) self._live_out = defaultdict(set)
self._skip_opt = skip_opt
def _add_connections(self, connections): def _add_connections(self, connections):
for node1, node2 in connections: for node1, node2 in connections:
...@@ -130,6 +131,10 @@ class ControlFlowGraph(object): ...@@ -130,6 +131,10 @@ class ControlFlowGraph(object):
block_desc, x, block_desc, x,
is_forward).type() != core.VarDesc.VarType.LOD_TENSOR: is_forward).type() != core.VarDesc.VarType.LOD_TENSOR:
return False return False
if x in self._skip_opt:
return False
if not self._find_var(block_desc, x, is_forward).shape():
return False
return True return True
self._build_graph() self._build_graph()
...@@ -140,6 +145,7 @@ class ControlFlowGraph(object): ...@@ -140,6 +145,7 @@ class ControlFlowGraph(object):
if op.type() == "while" or op.type() == "while_grad": if op.type() == "while" or op.type() == "while_grad":
continue continue
block_desc = op.block() block_desc = op.block()
self.current_block_desc = block_desc
is_forward = i < self._forward_num is_forward = i < self._forward_num
if self.pool: if self.pool:
defs_can_optimize = filter( defs_can_optimize = filter(
...@@ -197,28 +203,32 @@ def get_cfgs(input_program): ...@@ -197,28 +203,32 @@ def get_cfgs(input_program):
block_desc = pdesc.block(0) block_desc = pdesc.block(0)
op_size = block_desc.op_size() op_size = block_desc.op_size()
# Get global block ops # Get global block ops
ops_list.append(([block_desc.op(i) for i in range(op_size)], op_size)) ops_list.append(
([block_desc.op(i) for i in range(op_size)], op_size, set()))
while_sub_block_ids = [] while_sub_block_ids = []
while_grad_sub_block_ids = [] while_grad_sub_block_ids = []
while_pair = [] while_op_output = set()
while_block_id_pair = []
for i in range(op_size): for i in range(op_size):
op = block_desc.op(i) op = block_desc.op(i)
if op.type() == "while": if op.type() == "while":
while_sub_block_ids.append(op.attr("sub_block").id) while_sub_block_ids.append(op.attr("sub_block").id)
while_op_output.update(op.output_arg_names())
elif op.type() == "while_grad": elif op.type() == "while_grad":
while_grad_sub_block_ids.append(op.attr("sub_block").id) while_grad_sub_block_ids.append(op.attr("sub_block").id)
while_op_output.update(op.output_arg_names())
# Find while/while_grad block pair # Find while/while_grad block pair
for grad_id in while_grad_sub_block_ids: for grad_id in while_grad_sub_block_ids:
parent_id = pdesc.block(grad_id).parent parent_id = pdesc.block(grad_id).parent
if parent_id in while_sub_block_ids: if parent_id in while_sub_block_ids:
while_pair.append((parent_id, grad_id)) while_block_id_pair.append((parent_id, grad_id))
while_sub_block_ids.remove(parent_id) while_sub_block_ids.remove(parent_id)
# Get while/while_grad block ops # Get while/while_grad block ops
for parent_id, grad_id in while_pair: for parent_id, grad_id in while_block_id_pair:
while_block_ops = [] while_block_ops = []
while_block = pdesc.block(parent_id) while_block = pdesc.block(parent_id)
while_block_op_size = while_block.op_size() while_block_op_size = while_block.op_size()
...@@ -230,7 +240,7 @@ def get_cfgs(input_program): ...@@ -230,7 +240,7 @@ def get_cfgs(input_program):
for i in range(while_grad_block_op_size): for i in range(while_grad_block_op_size):
while_block_ops.append(while_grad_block.op(i)) while_block_ops.append(while_grad_block.op(i))
ops_list.append((while_block_ops, while_block_op_size)) ops_list.append((while_block_ops, while_block_op_size, while_op_output))
# Process rest while block ops # Process rest while block ops
for parent_id in while_sub_block_ids: for parent_id in while_sub_block_ids:
...@@ -242,7 +252,7 @@ def get_cfgs(input_program): ...@@ -242,7 +252,7 @@ def get_cfgs(input_program):
ops_list.append((while_block_ops, while_block_op_size)) ops_list.append((while_block_ops, while_block_op_size))
cfgs = [ControlFlowGraph(input_program, i, j) for i, j in ops_list] cfgs = [ControlFlowGraph(input_program, i, j, k) for i, j, k in ops_list]
return cfgs return cfgs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册