From 56d971c7fd0d847561cb0b9cf1941e7fc0d29343 Mon Sep 17 00:00:00 2001 From: Dong Daxiang <35550832+guru4elephant@users.noreply.github.com> Date: Fri, 11 Oct 2019 09:36:08 +0800 Subject: [PATCH] RecomputeOptimizer: rm unused ckpt and sort ckpt (#20108) (#20336) * rm unused ckpt and sort ckpt --- python/paddle/fluid/backward.py | 37 ++++++++++++----- .../fluid/tests/unittests/test_optimizer.py | 40 ++++++++++++++++++- 2 files changed, 66 insertions(+), 11 deletions(-) diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index 21a690a4328..1894ac41527 100644 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -19,14 +19,19 @@ from . import core import collections import copy import six +import logging from .. import compat as cpt from . import unique_name +from . import log_helper __all__ = [ 'append_backward', 'gradients', ] +_logger = log_helper.get_logger( + __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') + class ProgramStats(object): def __init__(self, block, ops): @@ -38,7 +43,7 @@ class ProgramStats(object): def get_input_nodes(self): input_names = [] for name in self.var_op_deps: - if len(self.var_op_deps[name]["var_as_output_ops"]) <= 0 and \ + if len(self.var_op_deps[name]["var_as_output_ops"]) == 0 and \ len(self.var_op_deps[name]["var_as_input_ops"]) > 0: if self.block.var(name).persistable: continue @@ -115,6 +120,22 @@ class ProgramStats(object): for op_idx in self.op_deps[i]["in_ops"]: self.op_deps[op_idx]["out_ops"].extend([i]) + def sort_checkpoints(self, checkpoints_name): + sorted_checkpoints = [] + for name in checkpoints_name: + if name not in self.var_op_deps: + _logger.debug( + "Recompute Optimizer: deleted %s from checkpoints, because it is not used in paddle program." + % name) + elif self.var_op_deps[name]["var_as_output_ops"] == []: + # input nodes + sorted_checkpoints.append((name, -1)) + else: + sorted_checkpoints.append( + (name, max(self.var_op_deps[name]["var_as_output_ops"]))) + sorted_checkpoints = sorted(sorted_checkpoints, key=lambda x: x[1]) + return [x[0] for x in sorted_checkpoints] + def _pretty_op_desc_(op_desc, prefix): out_s = "%s\tname:[%s]\n%s \tinputs:[%s]\n%s \toutputs:[%s]" % \ @@ -584,15 +605,17 @@ def _append_backward_ops_with_checkpoints_( """ checkpoints_name = [x.name for x in checkpoints] + checkpoints_name = list(set(checkpoints_name)) local_block = block.program._create_block() buffer_block = block.program._create_block() # 1) find ops between checkpoints, i.e. recompute_segments program_stat = ProgramStats(block, ops) program_stat.build_stats() + checkpoints_name = program_stat.sort_checkpoints(checkpoints_name) segments = [] - if len(checkpoints) == 1: + if len(checkpoints_name) == 1: # only one checkpoint max_op_idx = -1 var_group = [checkpoints_name[0]] @@ -616,8 +639,6 @@ def _append_backward_ops_with_checkpoints_( segments.append([min_idx, max_idx + 1]) start_idx += 1 - checkpoints_name = list(set(checkpoints_name)) - if segments != [] and segments[0][0] != 0: recompute_segments = [[0, segments[0][0]]] + segments else: @@ -625,7 +646,7 @@ def _append_backward_ops_with_checkpoints_( # 2) go through all forward ops and induct all variables that will be hold in memory vars_should_be_hold = [] - # a. variables that are used across segments will be held in memory + # a. variables that are used across segments will be held in memory for segment in recompute_segments: vars_should_be_hold.extend( program_stat.get_out_of_subgraph_vars(segment[0], segment[1])) @@ -635,10 +656,6 @@ def _append_backward_ops_with_checkpoints_( vars_should_be_hold.extend(program_stat.get_input_nodes()) vars_should_be_hold = list(set(vars_should_be_hold)) - # find variables that can not be deleted - grad_should_be_hold = [x + "@GRAD" for x in vars_should_be_hold] - vars_should_be_hold.extend(grad_should_be_hold) - # 3) go through each recompute_segments, add backward ops with forward recomputation grad_op_descs = [] var_name_dict = {} @@ -647,7 +664,7 @@ def _append_backward_ops_with_checkpoints_( max_calculated_op_position = len(ops) if recompute_segments == []: - # if there is no recompute segment, add backward ops like + # if there is no recompute segment, add backward ops like # _append_backward_ops_ function gap_ops = ops[0:max_calculated_op_position] for op in reversed(gap_ops): diff --git a/python/paddle/fluid/tests/unittests/test_optimizer.py b/python/paddle/fluid/tests/unittests/test_optimizer.py index 1c3fd17fd28..a10a5e36228 100644 --- a/python/paddle/fluid/tests/unittests/test_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_optimizer.py @@ -614,7 +614,7 @@ class TestLookaheadOptimizer(unittest.TestCase): class TestRecomputeOptimizer(unittest.TestCase): - def net(self): + def net(self, return_input=False): program = framework.Program() block = program.global_block() mul_x = block.create_parameter( @@ -652,6 +652,8 @@ class TestRecomputeOptimizer(unittest.TestCase): block.append_op( type="mean", inputs={"X": b2_out}, outputs={"Out": mean_out}) + if return_input == True: + return mul_x, mul_out, b1_out, b2_out, mean_out return mul_out, b1_out, b2_out, mean_out def test_no_checkpoint(self): @@ -723,6 +725,42 @@ class TestRecomputeOptimizer(unittest.TestCase): "elementwise_add_grad", "mul_grad", "sgd", "sgd", "sgd" ]) + def test_out_of_order_checkpoint(self): + mul_out, b1_out, b2_out, mean_out = self.net() + self.assertEqual(len(mean_out.block.ops), 4) + self.assertEqual([op.type for op in mean_out.block.ops], + ["mul", "elementwise_add", "elementwise_add", "mean"]) + sgd_optimizer = optimizer.SGD(learning_rate=1.0) + recompute_optimizer = optimizer.RecomputeOptimizer(sgd_optimizer) + recompute_optimizer._set_checkpoints([b2_out, mul_out]) + opts, params_grads = recompute_optimizer.minimize(mean_out) + + self.assertEqual(len(mean_out.block.ops), 13) + self.assertEqual([op.type for op in mean_out.block.ops], [ + "mul", "elementwise_add", "elementwise_add", "mean", + "fill_constant", "mean_grad", "elementwise_add", + "elementwise_add_grad", "elementwise_add_grad", "mul_grad", "sgd", + "sgd", "sgd" + ]) + + def test_input_as_checkpoints(self): + mul_x, mul_out, b1_out, b2_out, mean_out = self.net(return_input=True) + self.assertEqual(len(mean_out.block.ops), 4) + self.assertEqual([op.type for op in mean_out.block.ops], + ["mul", "elementwise_add", "elementwise_add", "mean"]) + sgd_optimizer = optimizer.SGD(learning_rate=1.0) + recompute_optimizer = optimizer.RecomputeOptimizer(sgd_optimizer) + recompute_optimizer._set_checkpoints([mul_x, b2_out]) + opts, params_grads = recompute_optimizer.minimize(mean_out) + + self.assertEqual(len(mean_out.block.ops), 14) + self.assertEqual([op.type for op in mean_out.block.ops], [ + "mul", "elementwise_add", "elementwise_add", "mean", + "fill_constant", "mean_grad", "mul", "elementwise_add", + "elementwise_add_grad", "elementwise_add_grad", "mul_grad", "sgd", + "sgd", "sgd" + ]) + def test_apply_gradients(self): mul_out, b1_out, b2_out, mean_out = self.net() sgd_optimizer = optimizer.SGD(learning_rate=1.0) -- GitLab