未验证 提交 abe32a49 编写于 作者: D dzhwinter 提交者: GitHub

"fix memory optimize bug in lr decay" (#12299)

上级 c2fe067e
...@@ -324,6 +324,8 @@ def _process_sub_block_pair(pdesc, sub_block_pair): ...@@ -324,6 +324,8 @@ def _process_sub_block_pair(pdesc, sub_block_pair):
sub_op_output = set() sub_op_output = set()
sub_op_output.update(sub_op_dict[fwd_id].output_arg_names()) sub_op_output.update(sub_op_dict[fwd_id].output_arg_names())
sub_op_output.update(sub_op_dict[grad_id].output_arg_names()) sub_op_output.update(sub_op_dict[grad_id].output_arg_names())
sub_op_output.update(sub_op_dict[fwd_id].input_arg_names())
sub_op_output.update(sub_op_dict[grad_id].input_arg_names())
ops_list.append((sub_block_ops, block_op_size, sub_op_output)) ops_list.append((sub_block_ops, block_op_size, sub_op_output))
# Process rest fwd_op block ops # Process rest fwd_op block ops
...@@ -335,6 +337,7 @@ def _process_sub_block_pair(pdesc, sub_block_pair): ...@@ -335,6 +337,7 @@ def _process_sub_block_pair(pdesc, sub_block_pair):
sub_block_ops.append(sub_block.op(i)) sub_block_ops.append(sub_block.op(i))
sub_op_output = set() sub_op_output = set()
sub_op_output.update(sub_op_dict[fwd_id].output_arg_names()) sub_op_output.update(sub_op_dict[fwd_id].output_arg_names())
sub_op_output.update(sub_op_dict[fwd_id].input_arg_names())
ops_list.append((sub_block_ops, sub_block_op_size, sub_op_output)) ops_list.append((sub_block_ops, sub_block_op_size, sub_op_output))
return ops_list return ops_list
...@@ -349,13 +352,17 @@ def _get_cfgs(input_program): ...@@ -349,13 +352,17 @@ def _get_cfgs(input_program):
pdesc = input_program.get_desc() pdesc = input_program.get_desc()
block_desc = pdesc.block(0) block_desc = pdesc.block(0)
op_size = block_desc.op_size() op_size = block_desc.op_size()
# Get global block ops
ops_list.append(
([block_desc.op(i) for i in range(op_size)], op_size, set()))
# Only process one level of nested subblock. # Only process one level of nested subblock.
ops_list.extend(_process_sub_block_pair(pdesc, SUB_BLOCK_PAIR)) ops_list.extend(_process_sub_block_pair(pdesc, SUB_BLOCK_PAIR))
skip_opt_set = set()
for _, _, skip_opt in ops_list:
skip_opt_set.update(skip_opt)
# Get global block ops
ops_list.insert(
0, ([block_desc.op(i) for i in range(op_size)], op_size, skip_opt_set))
cfgs = [ cfgs = [
ControlFlowGraph(input_program, ops, forward_num, skip_opt) ControlFlowGraph(input_program, ops, forward_num, skip_opt)
for ops, forward_num, skip_opt in ops_list for ops, forward_num, skip_opt in ops_list
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册