提交 56d971c7 编写于 作者: D Dong Daxiang 提交者: mapingshuo

RecomputeOptimizer: remove unused checkpoints and sort checkpoints (#20108) (#20336)

* Remove unused checkpoints and sort the remaining checkpoints
上级 d4f11bd5
......@@ -19,14 +19,19 @@ from . import core
import collections
import copy
import six
import logging
from .. import compat as cpt
from . import unique_name
from . import log_helper
# Names exported by a star-import of this module.
__all__ = [
'append_backward',
'gradients',
]
# Module-level logger, requested from log_helper with level logging.INFO;
# each record is rendered as "<asctime>-<levelname>: <message>".
_logger = log_helper.get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
class ProgramStats(object):
def __init__(self, block, ops):
......@@ -38,7 +43,7 @@ class ProgramStats(object):
def get_input_nodes(self):
input_names = []
for name in self.var_op_deps:
if len(self.var_op_deps[name]["var_as_output_ops"]) <= 0 and \
if len(self.var_op_deps[name]["var_as_output_ops"]) == 0 and \
len(self.var_op_deps[name]["var_as_input_ops"]) > 0:
if self.block.var(name).persistable:
continue
......@@ -115,6 +120,22 @@ class ProgramStats(object):
for op_idx in self.op_deps[i]["in_ops"]:
self.op_deps[op_idx]["out_ops"].extend([i])
def sort_checkpoints(self, checkpoints_name):
    """Drop unknown checkpoints and order the rest by their producing op.

    Args:
        checkpoints_name (list[str]): candidate checkpoint variable names.

    Returns:
        list[str]: the names that appear in ``self.var_op_deps``, sorted in
        ascending order of the largest op index recorded in their
        ``"var_as_output_ops"`` entry. A name with no producing op gets
        the key -1 and therefore sorts before every produced variable.
        Names sharing the same key are ordered by name, so the result does
        not depend on the order of ``checkpoints_name``.
    """
    ranked = []
    for name in checkpoints_name:
        if name not in self.var_op_deps:
            # Pass the name as a logger argument so the message is only
            # formatted when debug logging is actually enabled.
            _logger.debug(
                "Recompute Optimizer: deleted %s from checkpoints, because it is not used in paddle program.",
                name)
            continue
        producers = self.var_op_deps[name]["var_as_output_ops"]
        # Truthiness test: any empty sequence means "no producing op";
        # calling max() on it would raise ValueError.
        last_producer = max(producers) if producers else -1
        ranked.append((last_producer, name))
    # Tuples compare by op index first, then by name. The name tie-break
    # keeps the output reproducible when the caller passes names in an
    # arbitrary order (e.g. taken from a set).
    ranked.sort()
    return [name for _, name in ranked]
def _pretty_op_desc_(op_desc, prefix):
out_s = "%s\tname:[%s]\n%s \tinputs:[%s]\n%s \toutputs:[%s]" % \
......@@ -584,15 +605,17 @@ def _append_backward_ops_with_checkpoints_(
"""
checkpoints_name = [x.name for x in checkpoints]
checkpoints_name = list(set(checkpoints_name))
local_block = block.program._create_block()
buffer_block = block.program._create_block()
# 1) find ops between checkpoints, i.e. recompute_segments
program_stat = ProgramStats(block, ops)
program_stat.build_stats()
checkpoints_name = program_stat.sort_checkpoints(checkpoints_name)
segments = []
if len(checkpoints) == 1:
if len(checkpoints_name) == 1:
# only one checkpoint
max_op_idx = -1
var_group = [checkpoints_name[0]]
......@@ -616,8 +639,6 @@ def _append_backward_ops_with_checkpoints_(
segments.append([min_idx, max_idx + 1])
start_idx += 1
checkpoints_name = list(set(checkpoints_name))
if segments != [] and segments[0][0] != 0:
recompute_segments = [[0, segments[0][0]]] + segments
else:
......@@ -625,7 +646,7 @@ def _append_backward_ops_with_checkpoints_(
# 2) go through all forward ops and induct all variables that will be hold in memory
vars_should_be_hold = []
# a. variables that are used across segments will be held in memory
# a. variables that are used across segments will be held in memory
for segment in recompute_segments:
vars_should_be_hold.extend(
program_stat.get_out_of_subgraph_vars(segment[0], segment[1]))
......@@ -635,10 +656,6 @@ def _append_backward_ops_with_checkpoints_(
vars_should_be_hold.extend(program_stat.get_input_nodes())
vars_should_be_hold = list(set(vars_should_be_hold))
# find variables that can not be deleted
grad_should_be_hold = [x + "@GRAD" for x in vars_should_be_hold]
vars_should_be_hold.extend(grad_should_be_hold)
# 3) go through each recompute_segments, add backward ops with forward recomputation
grad_op_descs = []
var_name_dict = {}
......@@ -647,7 +664,7 @@ def _append_backward_ops_with_checkpoints_(
max_calculated_op_position = len(ops)
if recompute_segments == []:
# if there is no recompute segment, add backward ops like
# if there is no recompute segment, add backward ops like
# _append_backward_ops_ function
gap_ops = ops[0:max_calculated_op_position]
for op in reversed(gap_ops):
......
......@@ -614,7 +614,7 @@ class TestLookaheadOptimizer(unittest.TestCase):
class TestRecomputeOptimizer(unittest.TestCase):
def net(self):
def net(self, return_input=False):
program = framework.Program()
block = program.global_block()
mul_x = block.create_parameter(
......@@ -652,6 +652,8 @@ class TestRecomputeOptimizer(unittest.TestCase):
block.append_op(
type="mean", inputs={"X": b2_out}, outputs={"Out": mean_out})
if return_input == True:
return mul_x, mul_out, b1_out, b2_out, mean_out
return mul_out, b1_out, b2_out, mean_out
def test_no_checkpoint(self):
......@@ -723,6 +725,42 @@ class TestRecomputeOptimizer(unittest.TestCase):
"elementwise_add_grad", "mul_grad", "sgd", "sgd", "sgd"
])
def test_out_of_order_checkpoint(self):
    """Checkpoints passed in reverse program order still yield the expected ops."""
    mul_out, _b1_out, b2_out, loss = self.net()

    # The forward program consists of four ops before minimize() runs.
    forward_types = ["mul", "elementwise_add", "elementwise_add", "mean"]
    self.assertEqual(len(loss.block.ops), 4)
    self.assertEqual([op.type for op in loss.block.ops], forward_types)

    # b2_out is listed before mul_out although mul_out is produced first.
    recompute = optimizer.RecomputeOptimizer(
        optimizer.SGD(learning_rate=1.0))
    recompute._set_checkpoints([b2_out, mul_out])
    opts, params_grads = recompute.minimize(loss)

    expected_types = [
        "mul", "elementwise_add", "elementwise_add", "mean",
        "fill_constant", "mean_grad", "elementwise_add",
        "elementwise_add_grad", "elementwise_add_grad", "mul_grad", "sgd",
        "sgd", "sgd"
    ]
    self.assertEqual(len(loss.block.ops), 13)
    self.assertEqual([op.type for op in loss.block.ops], expected_types)
def test_input_as_checkpoints(self):
    """A network input (mul_x) may be used as a checkpoint next to b2_out."""
    mul_x, _mul_out, _b1_out, b2_out, loss = self.net(return_input=True)

    # The forward program consists of four ops before minimize() runs.
    forward_types = ["mul", "elementwise_add", "elementwise_add", "mean"]
    self.assertEqual(len(loss.block.ops), 4)
    self.assertEqual([op.type for op in loss.block.ops], forward_types)

    recompute = optimizer.RecomputeOptimizer(
        optimizer.SGD(learning_rate=1.0))
    recompute._set_checkpoints([mul_x, b2_out])
    opts, params_grads = recompute.minimize(loss)

    # Compared with the forward program, the backward part re-runs "mul"
    # and one "elementwise_add" before the gradient ops.
    expected_types = [
        "mul", "elementwise_add", "elementwise_add", "mean",
        "fill_constant", "mean_grad", "mul", "elementwise_add",
        "elementwise_add_grad", "elementwise_add_grad", "mul_grad", "sgd",
        "sgd", "sgd"
    ]
    self.assertEqual(len(loss.block.ops), 14)
    self.assertEqual([op.type for op in loss.block.ops], expected_types)
def test_apply_gradients(self):
mul_out, b1_out, b2_out, mean_out = self.net()
sgd_optimizer = optimizer.SGD(learning_rate=1.0)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册