提交 56d971c7 编写于 作者: D Dong Daxiang 提交者: mapingshuo

RecomputeOptimizer: rm unused ckpt and sort ckpt (#20108) (#20336)

* rm unused ckpt and sort ckpt
上级 d4f11bd5
...@@ -19,14 +19,19 @@ from . import core ...@@ -19,14 +19,19 @@ from . import core
import collections import collections
import copy import copy
import six import six
import logging
from .. import compat as cpt from .. import compat as cpt
from . import unique_name from . import unique_name
from . import log_helper
__all__ = [ __all__ = [
'append_backward', 'append_backward',
'gradients', 'gradients',
] ]
_logger = log_helper.get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
class ProgramStats(object): class ProgramStats(object):
def __init__(self, block, ops): def __init__(self, block, ops):
...@@ -38,7 +43,7 @@ class ProgramStats(object): ...@@ -38,7 +43,7 @@ class ProgramStats(object):
def get_input_nodes(self): def get_input_nodes(self):
input_names = [] input_names = []
for name in self.var_op_deps: for name in self.var_op_deps:
if len(self.var_op_deps[name]["var_as_output_ops"]) <= 0 and \ if len(self.var_op_deps[name]["var_as_output_ops"]) == 0 and \
len(self.var_op_deps[name]["var_as_input_ops"]) > 0: len(self.var_op_deps[name]["var_as_input_ops"]) > 0:
if self.block.var(name).persistable: if self.block.var(name).persistable:
continue continue
...@@ -115,6 +120,22 @@ class ProgramStats(object): ...@@ -115,6 +120,22 @@ class ProgramStats(object):
for op_idx in self.op_deps[i]["in_ops"]: for op_idx in self.op_deps[i]["in_ops"]:
self.op_deps[op_idx]["out_ops"].extend([i]) self.op_deps[op_idx]["out_ops"].extend([i])
def sort_checkpoints(self, checkpoints_name):
sorted_checkpoints = []
for name in checkpoints_name:
if name not in self.var_op_deps:
_logger.debug(
"Recompute Optimizer: deleted %s from checkpoints, because it is not used in paddle program."
% name)
elif self.var_op_deps[name]["var_as_output_ops"] == []:
# input nodes
sorted_checkpoints.append((name, -1))
else:
sorted_checkpoints.append(
(name, max(self.var_op_deps[name]["var_as_output_ops"])))
sorted_checkpoints = sorted(sorted_checkpoints, key=lambda x: x[1])
return [x[0] for x in sorted_checkpoints]
def _pretty_op_desc_(op_desc, prefix): def _pretty_op_desc_(op_desc, prefix):
out_s = "%s\tname:[%s]\n%s \tinputs:[%s]\n%s \toutputs:[%s]" % \ out_s = "%s\tname:[%s]\n%s \tinputs:[%s]\n%s \toutputs:[%s]" % \
...@@ -584,15 +605,17 @@ def _append_backward_ops_with_checkpoints_( ...@@ -584,15 +605,17 @@ def _append_backward_ops_with_checkpoints_(
""" """
checkpoints_name = [x.name for x in checkpoints] checkpoints_name = [x.name for x in checkpoints]
checkpoints_name = list(set(checkpoints_name))
local_block = block.program._create_block() local_block = block.program._create_block()
buffer_block = block.program._create_block() buffer_block = block.program._create_block()
# 1) find ops between checkpoints, i.e. recompute_segments # 1) find ops between checkpoints, i.e. recompute_segments
program_stat = ProgramStats(block, ops) program_stat = ProgramStats(block, ops)
program_stat.build_stats() program_stat.build_stats()
checkpoints_name = program_stat.sort_checkpoints(checkpoints_name)
segments = [] segments = []
if len(checkpoints) == 1: if len(checkpoints_name) == 1:
# only one checkpoint # only one checkpoint
max_op_idx = -1 max_op_idx = -1
var_group = [checkpoints_name[0]] var_group = [checkpoints_name[0]]
...@@ -616,8 +639,6 @@ def _append_backward_ops_with_checkpoints_( ...@@ -616,8 +639,6 @@ def _append_backward_ops_with_checkpoints_(
segments.append([min_idx, max_idx + 1]) segments.append([min_idx, max_idx + 1])
start_idx += 1 start_idx += 1
checkpoints_name = list(set(checkpoints_name))
if segments != [] and segments[0][0] != 0: if segments != [] and segments[0][0] != 0:
recompute_segments = [[0, segments[0][0]]] + segments recompute_segments = [[0, segments[0][0]]] + segments
else: else:
...@@ -635,10 +656,6 @@ def _append_backward_ops_with_checkpoints_( ...@@ -635,10 +656,6 @@ def _append_backward_ops_with_checkpoints_(
vars_should_be_hold.extend(program_stat.get_input_nodes()) vars_should_be_hold.extend(program_stat.get_input_nodes())
vars_should_be_hold = list(set(vars_should_be_hold)) vars_should_be_hold = list(set(vars_should_be_hold))
# find variables that can not be deleted
grad_should_be_hold = [x + "@GRAD" for x in vars_should_be_hold]
vars_should_be_hold.extend(grad_should_be_hold)
# 3) go through each recompute_segments, add backward ops with forward recomputation # 3) go through each recompute_segments, add backward ops with forward recomputation
grad_op_descs = [] grad_op_descs = []
var_name_dict = {} var_name_dict = {}
......
...@@ -614,7 +614,7 @@ class TestLookaheadOptimizer(unittest.TestCase): ...@@ -614,7 +614,7 @@ class TestLookaheadOptimizer(unittest.TestCase):
class TestRecomputeOptimizer(unittest.TestCase): class TestRecomputeOptimizer(unittest.TestCase):
def net(self): def net(self, return_input=False):
program = framework.Program() program = framework.Program()
block = program.global_block() block = program.global_block()
mul_x = block.create_parameter( mul_x = block.create_parameter(
...@@ -652,6 +652,8 @@ class TestRecomputeOptimizer(unittest.TestCase): ...@@ -652,6 +652,8 @@ class TestRecomputeOptimizer(unittest.TestCase):
block.append_op( block.append_op(
type="mean", inputs={"X": b2_out}, outputs={"Out": mean_out}) type="mean", inputs={"X": b2_out}, outputs={"Out": mean_out})
if return_input == True:
return mul_x, mul_out, b1_out, b2_out, mean_out
return mul_out, b1_out, b2_out, mean_out return mul_out, b1_out, b2_out, mean_out
def test_no_checkpoint(self): def test_no_checkpoint(self):
...@@ -723,6 +725,42 @@ class TestRecomputeOptimizer(unittest.TestCase): ...@@ -723,6 +725,42 @@ class TestRecomputeOptimizer(unittest.TestCase):
"elementwise_add_grad", "mul_grad", "sgd", "sgd", "sgd" "elementwise_add_grad", "mul_grad", "sgd", "sgd", "sgd"
]) ])
def test_out_of_order_checkpoint(self):
mul_out, b1_out, b2_out, mean_out = self.net()
self.assertEqual(len(mean_out.block.ops), 4)
self.assertEqual([op.type for op in mean_out.block.ops],
["mul", "elementwise_add", "elementwise_add", "mean"])
sgd_optimizer = optimizer.SGD(learning_rate=1.0)
recompute_optimizer = optimizer.RecomputeOptimizer(sgd_optimizer)
recompute_optimizer._set_checkpoints([b2_out, mul_out])
opts, params_grads = recompute_optimizer.minimize(mean_out)
self.assertEqual(len(mean_out.block.ops), 13)
self.assertEqual([op.type for op in mean_out.block.ops], [
"mul", "elementwise_add", "elementwise_add", "mean",
"fill_constant", "mean_grad", "elementwise_add",
"elementwise_add_grad", "elementwise_add_grad", "mul_grad", "sgd",
"sgd", "sgd"
])
def test_input_as_checkpoints(self):
mul_x, mul_out, b1_out, b2_out, mean_out = self.net(return_input=True)
self.assertEqual(len(mean_out.block.ops), 4)
self.assertEqual([op.type for op in mean_out.block.ops],
["mul", "elementwise_add", "elementwise_add", "mean"])
sgd_optimizer = optimizer.SGD(learning_rate=1.0)
recompute_optimizer = optimizer.RecomputeOptimizer(sgd_optimizer)
recompute_optimizer._set_checkpoints([mul_x, b2_out])
opts, params_grads = recompute_optimizer.minimize(mean_out)
self.assertEqual(len(mean_out.block.ops), 14)
self.assertEqual([op.type for op in mean_out.block.ops], [
"mul", "elementwise_add", "elementwise_add", "mean",
"fill_constant", "mean_grad", "mul", "elementwise_add",
"elementwise_add_grad", "elementwise_add_grad", "mul_grad", "sgd",
"sgd", "sgd"
])
def test_apply_gradients(self): def test_apply_gradients(self):
mul_out, b1_out, b2_out, mean_out = self.net() mul_out, b1_out, b2_out, mean_out = self.net()
sgd_optimizer = optimizer.SGD(learning_rate=1.0) sgd_optimizer = optimizer.SGD(learning_rate=1.0)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册