From 7a5f3f750bcbe084796f7840ae2937925432b413 Mon Sep 17 00:00:00 2001 From: Wu Yi Date: Fri, 28 Sep 2018 12:58:39 +0800 Subject: [PATCH] Fix memory optimization with dist train (#13535) * show detail error log on ci * test * fix memopt and dist * update apispec * will fix different batch issue test=develop --- paddle/fluid/API.spec | 4 +- .../fluid/tests/unittests/test_dist_base.py | 6 +- .../tests/unittests/test_dist_se_resnext.py | 15 ++- .../memory_optimization_transpiler.py | 112 ++++++++++++++++-- 4 files changed, 112 insertions(+), 25 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index def7ad9c40..452806a20e 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -21,7 +21,7 @@ paddle.fluid.DistributeTranspiler.get_pserver_programs ArgSpec(args=['self', 'en paddle.fluid.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.DistributeTranspiler.get_trainer_program ArgSpec(args=['self', 'wait_port'], varargs=None, keywords=None, defaults=(True,)) paddle.fluid.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode', 'startup_program', 'current_endpoint'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True, None, '127.0.0.1:6174')) -paddle.fluid.memory_optimize ArgSpec(args=['input_program', 'skip_opt_set', 'print_log', 'level'], varargs=None, keywords=None, defaults=(None, False, 0)) +paddle.fluid.memory_optimize ArgSpec(args=['input_program', 'skip_opt_set', 'print_log', 'level', 'skip_grads'], varargs=None, keywords=None, defaults=(None, False, 0, False)) paddle.fluid.release_memory ArgSpec(args=['input_program', 'skip_opt_set'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.DistributeTranspilerConfig.__init__ paddle.fluid.ParallelExecutor.__init__ ArgSpec(args=['self', 'use_cuda', 'loss_name', 'main_program', 'share_vars_from', 'exec_strategy', 'build_strategy', 'num_trainers', 'trainer_id', 'scope'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 1, 0, None)) @@ -309,7 +309,7 @@ paddle.fluid.transpiler.DistributeTranspiler.get_pserver_programs ArgSpec(args=[ paddle.fluid.transpiler.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.transpiler.DistributeTranspiler.get_trainer_program ArgSpec(args=['self', 'wait_port'], varargs=None, keywords=None, defaults=(True,)) paddle.fluid.transpiler.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode', 'startup_program', 'current_endpoint'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True, None, '127.0.0.1:6174')) -paddle.fluid.transpiler.memory_optimize ArgSpec(args=['input_program', 'skip_opt_set', 'print_log', 'level'], varargs=None, keywords=None, defaults=(None, False, 0)) +paddle.fluid.transpiler.memory_optimize ArgSpec(args=['input_program', 'skip_opt_set', 'print_log', 'level', 'skip_grads'], varargs=None, keywords=None, defaults=(None, False, 0, False)) paddle.fluid.transpiler.release_memory ArgSpec(args=['input_program', 'skip_opt_set'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.transpiler.HashName.__init__ ArgSpec(args=['self', 'pserver_endpoints'], varargs=None, keywords=None, defaults=None) paddle.fluid.transpiler.HashName.dispatch ArgSpec(args=['self', 'varlist'], varargs=None, keywords=None, defaults=None) diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index 856980e546..0b9af6d7f6 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -50,9 +50,7 @@ class TestDistRunnerBase(object): def run_pserver(self, args): self.get_model(batch_size=2) - - if args.mem_opt: - fluid.memory_optimize(fluid.default_main_program()) + # NOTE: pserver should not call memory optimize t = self.get_transpiler(args.trainer_id, fluid.default_main_program(), args.endpoints, args.trainers, args.sync_mode) @@ -70,7 +68,7 @@ class TestDistRunnerBase(object): self.get_model(batch_size=2) if args.mem_opt: - fluid.memory_optimize(fluid.default_main_program()) + fluid.memory_optimize(fluid.default_main_program(), skip_grads=True) if args.is_dist: t = self.get_transpiler(args.trainer_id, fluid.default_main_program(), diff --git a/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py b/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py index d2d927aca8..c0989ca709 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py +++ b/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py @@ -26,14 +26,13 @@ class TestDistSeResneXt2x2(TestDistBase): self.check_with_place("dist_se_resnext.py", delta=100) -# TODO(typhoonzero): fix this test -# class TestDistseResnXt2x2WithMemopt(TestDistBase): -# def _setup_config(self): -# self._sync_mode = True -# self._mem_opt = True - -# def test_dist_train(self): -# self.check_with_place("dist_se_resnext.py", delta=1e-7) +class TestDistseResnXt2x2WithMemopt(TestDistBase): + def _setup_config(self): + self._sync_mode = True + self._mem_opt = True + + def test_dist_train(self): + self.check_with_place("dist_se_resnext.py", delta=100) class TestDistSeResneXt2x2Async(TestDistBase): diff --git a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py index d5aa54d752..861bb5fae5 100755 --- a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py +++ b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py @@ -14,10 +14,10 @@ from __future__ import print_function -from collections import defaultdict, OrderedDict, Callable +from collections import defaultdict, MutableSet from .. import core from ... import compat as cpt -from ..framework import Program, default_main_program, Parameter, Variable +from ..framework import Program, default_main_program, Parameter, Variable, core from ..backward import _rename_arg_ from functools import reduce from six.moves import range @@ -44,17 +44,82 @@ SUB_BLOCK_PAIR = [("while", "while_grad"), ("parallel_do", "parallel_do_grad"), PRINT_LOG = False +class OrderedSet(MutableSet): + def __init__(self, iterable=None): + self.end = end = [] + end += [None, end, end] # sentinel node for doubly linked list + self.map = {} # key --> [key, prev, next] + if iterable is not None: + self |= iterable + + def __len__(self): + return len(self.map) + + def __contains__(self, key): + return key in self.map + + def add(self, key): + if key not in self.map: + end = self.end + curr = end[1] + curr[2] = end[1] = self.map[key] = [key, curr, end] + + def update(self, other): + for e in other: + self.add(e) + + def discard(self, key): + if key in self.map: + key, prev, next = self.map.pop(key) + prev[2] = next + next[1] = prev + + def remove(self, key): + self.discard(key) + + def __iter__(self): + end = self.end + curr = end[2] + while curr is not end: + yield curr[0] + curr = curr[2] + + def __reversed__(self): + end = self.end + curr = end[1] + while curr is not end: + yield curr[0] + curr = curr[1] + + def pop(self, last=True): + if not self: + raise KeyError('set is empty') + key = self.end[1][0] if last else self.end[2][0] + self.discard(key) + return key + + def __repr__(self): + if not self: + return '%s()' % (self.__class__.__name__, ) + return '%s(%r)' % (self.__class__.__name__, list(self)) + + def __eq__(self, other): + if isinstance(other, OrderedSet): + return len(self) == len(other) and list(self) == list(other) + return set(self) == set(other) + + class ControlFlowGraph(object): def __init__(self, program, ops, forward_num, skip_opt): self._program = program self._ops = ops self._forward_num = forward_num - self._successors = defaultdict(set) - self._presuccessors = defaultdict(set) - self._uses = defaultdict(set) - self._defs = defaultdict(set) - self._live_in = defaultdict(set) - self._live_out = defaultdict(set) + self._successors = defaultdict(OrderedSet) + self._presuccessors = defaultdict(OrderedSet) + self._uses = defaultdict(OrderedSet) + self._defs = defaultdict(OrderedSet) + self._live_in = defaultdict(OrderedSet) + self._live_out = defaultdict(OrderedSet) self._skip_opt = skip_opt self.pool = [] @@ -116,7 +181,7 @@ class ControlFlowGraph(object): # NOTE: must sort the in_diff set for cases that get different cache var. # FIXME(typhoonzero): maybe use a "sorted set" is better than this. can_optimize = [ - x for x in sorted(list(in_diff)) + x for x in in_diff if self._check_var_validity(block_desc, x, is_forward) ] if can_optimize: @@ -224,7 +289,7 @@ class ControlFlowGraph(object): if self.pool: # NOTE: must sort the in_diff set for cases that get different cache var. defs_can_optimize = [ - x for x in sorted(list(self._defs[i])) + x for x in self._defs[i] if self._check_var_validity(block_desc, x, is_forward) ] out_pair = [ @@ -381,7 +446,19 @@ def _get_cfgs(input_program): return cfgs -def memory_optimize(input_program, skip_opt_set=None, print_log=False, level=0): +def _is_opt_role_op(op): + op_maker = core.op_proto_and_checker_maker + optimize_role = core.op_proto_and_checker_maker.OpRole.Optimize + if op_maker.kOpRoleAttrName() in op.attr_names and \ + int(op.all_attrs()[op_maker.kOpRoleAttrName()]) == int(optimize_role): + return True + + +def memory_optimize(input_program, + skip_opt_set=None, + print_log=False, + level=0, + skip_grads=False): """Optimize memory by reusing var memory. Note: it doesn't not support subblock nested in subblock. @@ -398,6 +475,19 @@ def memory_optimize(input_program, skip_opt_set=None, print_log=False, level=0): raise ValueError("only support opt_level 0 or 1.") global PRINT_LOG PRINT_LOG = print_log + if skip_grads: + grad_set = set() + OP_ROLE_VAR = core.op_proto_and_checker_maker.kOpRoleVarAttrName() + for op in input_program.global_block().ops: + if _is_opt_role_op(op): + if op.attr(OP_ROLE_VAR): + grad_name = op.attr(OP_ROLE_VAR)[1] + grad_set.add(grad_name) + if not skip_opt_set: + skip_opt_set = grad_set + else: + skip_opt_set.update(grad_set) + cfgs = _get_cfgs(input_program) for cfg in cfgs: cfg.memory_optimize(skip_opt_set=skip_opt_set, level=level) -- GitLab