未验证 提交 7a5f3f75 编写于 作者: W Wu Yi 提交者: GitHub

Fix memory optimization with dist train (#13535)

* show detail error log on ci

* test

* fix memopt and dist

* update apispec

* will fix different batch issue test=develop
上级 c8744d11
...@@ -21,7 +21,7 @@ paddle.fluid.DistributeTranspiler.get_pserver_programs ArgSpec(args=['self', 'en ...@@ -21,7 +21,7 @@ paddle.fluid.DistributeTranspiler.get_pserver_programs ArgSpec(args=['self', 'en
paddle.fluid.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.DistributeTranspiler.get_trainer_program ArgSpec(args=['self', 'wait_port'], varargs=None, keywords=None, defaults=(True,)) paddle.fluid.DistributeTranspiler.get_trainer_program ArgSpec(args=['self', 'wait_port'], varargs=None, keywords=None, defaults=(True,))
paddle.fluid.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode', 'startup_program', 'current_endpoint'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True, None, '127.0.0.1:6174')) paddle.fluid.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode', 'startup_program', 'current_endpoint'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True, None, '127.0.0.1:6174'))
paddle.fluid.memory_optimize ArgSpec(args=['input_program', 'skip_opt_set', 'print_log', 'level'], varargs=None, keywords=None, defaults=(None, False, 0)) paddle.fluid.memory_optimize ArgSpec(args=['input_program', 'skip_opt_set', 'print_log', 'level', 'skip_grads'], varargs=None, keywords=None, defaults=(None, False, 0, False))
paddle.fluid.release_memory ArgSpec(args=['input_program', 'skip_opt_set'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.release_memory ArgSpec(args=['input_program', 'skip_opt_set'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.DistributeTranspilerConfig.__init__ paddle.fluid.DistributeTranspilerConfig.__init__
paddle.fluid.ParallelExecutor.__init__ ArgSpec(args=['self', 'use_cuda', 'loss_name', 'main_program', 'share_vars_from', 'exec_strategy', 'build_strategy', 'num_trainers', 'trainer_id', 'scope'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 1, 0, None)) paddle.fluid.ParallelExecutor.__init__ ArgSpec(args=['self', 'use_cuda', 'loss_name', 'main_program', 'share_vars_from', 'exec_strategy', 'build_strategy', 'num_trainers', 'trainer_id', 'scope'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 1, 0, None))
...@@ -309,7 +309,7 @@ paddle.fluid.transpiler.DistributeTranspiler.get_pserver_programs ArgSpec(args=[ ...@@ -309,7 +309,7 @@ paddle.fluid.transpiler.DistributeTranspiler.get_pserver_programs ArgSpec(args=[
paddle.fluid.transpiler.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.transpiler.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.transpiler.DistributeTranspiler.get_trainer_program ArgSpec(args=['self', 'wait_port'], varargs=None, keywords=None, defaults=(True,)) paddle.fluid.transpiler.DistributeTranspiler.get_trainer_program ArgSpec(args=['self', 'wait_port'], varargs=None, keywords=None, defaults=(True,))
paddle.fluid.transpiler.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode', 'startup_program', 'current_endpoint'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True, None, '127.0.0.1:6174')) paddle.fluid.transpiler.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode', 'startup_program', 'current_endpoint'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True, None, '127.0.0.1:6174'))
paddle.fluid.transpiler.memory_optimize ArgSpec(args=['input_program', 'skip_opt_set', 'print_log', 'level'], varargs=None, keywords=None, defaults=(None, False, 0)) paddle.fluid.transpiler.memory_optimize ArgSpec(args=['input_program', 'skip_opt_set', 'print_log', 'level', 'skip_grads'], varargs=None, keywords=None, defaults=(None, False, 0, False))
paddle.fluid.transpiler.release_memory ArgSpec(args=['input_program', 'skip_opt_set'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.transpiler.release_memory ArgSpec(args=['input_program', 'skip_opt_set'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.transpiler.HashName.__init__ ArgSpec(args=['self', 'pserver_endpoints'], varargs=None, keywords=None, defaults=None) paddle.fluid.transpiler.HashName.__init__ ArgSpec(args=['self', 'pserver_endpoints'], varargs=None, keywords=None, defaults=None)
paddle.fluid.transpiler.HashName.dispatch ArgSpec(args=['self', 'varlist'], varargs=None, keywords=None, defaults=None) paddle.fluid.transpiler.HashName.dispatch ArgSpec(args=['self', 'varlist'], varargs=None, keywords=None, defaults=None)
......
...@@ -50,9 +50,7 @@ class TestDistRunnerBase(object): ...@@ -50,9 +50,7 @@ class TestDistRunnerBase(object):
def run_pserver(self, args): def run_pserver(self, args):
self.get_model(batch_size=2) self.get_model(batch_size=2)
# NOTE: pserver should not call memory optimize
if args.mem_opt:
fluid.memory_optimize(fluid.default_main_program())
t = self.get_transpiler(args.trainer_id, t = self.get_transpiler(args.trainer_id,
fluid.default_main_program(), args.endpoints, fluid.default_main_program(), args.endpoints,
args.trainers, args.sync_mode) args.trainers, args.sync_mode)
...@@ -70,7 +68,7 @@ class TestDistRunnerBase(object): ...@@ -70,7 +68,7 @@ class TestDistRunnerBase(object):
self.get_model(batch_size=2) self.get_model(batch_size=2)
if args.mem_opt: if args.mem_opt:
fluid.memory_optimize(fluid.default_main_program()) fluid.memory_optimize(fluid.default_main_program(), skip_grads=True)
if args.is_dist: if args.is_dist:
t = self.get_transpiler(args.trainer_id, t = self.get_transpiler(args.trainer_id,
fluid.default_main_program(), fluid.default_main_program(),
......
...@@ -26,14 +26,13 @@ class TestDistSeResneXt2x2(TestDistBase): ...@@ -26,14 +26,13 @@ class TestDistSeResneXt2x2(TestDistBase):
self.check_with_place("dist_se_resnext.py", delta=100) self.check_with_place("dist_se_resnext.py", delta=100)
# TODO(typhoonzero): fix this test class TestDistseResnXt2x2WithMemopt(TestDistBase):
# class TestDistseResnXt2x2WithMemopt(TestDistBase): def _setup_config(self):
# def _setup_config(self): self._sync_mode = True
# self._sync_mode = True self._mem_opt = True
# self._mem_opt = True
def test_dist_train(self):
# def test_dist_train(self): self.check_with_place("dist_se_resnext.py", delta=100)
# self.check_with_place("dist_se_resnext.py", delta=1e-7)
class TestDistSeResneXt2x2Async(TestDistBase): class TestDistSeResneXt2x2Async(TestDistBase):
......
...@@ -14,10 +14,10 @@ ...@@ -14,10 +14,10 @@
from __future__ import print_function from __future__ import print_function
from collections import defaultdict, OrderedDict, Callable from collections import defaultdict, MutableSet
from .. import core from .. import core
from ... import compat as cpt from ... import compat as cpt
from ..framework import Program, default_main_program, Parameter, Variable from ..framework import Program, default_main_program, Parameter, Variable, core
from ..backward import _rename_arg_ from ..backward import _rename_arg_
from functools import reduce from functools import reduce
from six.moves import range from six.moves import range
...@@ -44,17 +44,82 @@ SUB_BLOCK_PAIR = [("while", "while_grad"), ("parallel_do", "parallel_do_grad"), ...@@ -44,17 +44,82 @@ SUB_BLOCK_PAIR = [("while", "while_grad"), ("parallel_do", "parallel_do_grad"),
PRINT_LOG = False PRINT_LOG = False
class OrderedSet(MutableSet):
def __init__(self, iterable=None):
self.end = end = []
end += [None, end, end] # sentinel node for doubly linked list
self.map = {} # key --> [key, prev, next]
if iterable is not None:
self |= iterable
def __len__(self):
return len(self.map)
def __contains__(self, key):
return key in self.map
def add(self, key):
if key not in self.map:
end = self.end
curr = end[1]
curr[2] = end[1] = self.map[key] = [key, curr, end]
def update(self, other):
for e in other:
self.add(e)
def discard(self, key):
if key in self.map:
key, prev, next = self.map.pop(key)
prev[2] = next
next[1] = prev
def remove(self, key):
self.discard(key)
def __iter__(self):
end = self.end
curr = end[2]
while curr is not end:
yield curr[0]
curr = curr[2]
def __reversed__(self):
end = self.end
curr = end[1]
while curr is not end:
yield curr[0]
curr = curr[1]
def pop(self, last=True):
if not self:
raise KeyError('set is empty')
key = self.end[1][0] if last else self.end[2][0]
self.discard(key)
return key
def __repr__(self):
if not self:
return '%s()' % (self.__class__.__name__, )
return '%s(%r)' % (self.__class__.__name__, list(self))
def __eq__(self, other):
if isinstance(other, OrderedSet):
return len(self) == len(other) and list(self) == list(other)
return set(self) == set(other)
class ControlFlowGraph(object): class ControlFlowGraph(object):
def __init__(self, program, ops, forward_num, skip_opt): def __init__(self, program, ops, forward_num, skip_opt):
self._program = program self._program = program
self._ops = ops self._ops = ops
self._forward_num = forward_num self._forward_num = forward_num
self._successors = defaultdict(set) self._successors = defaultdict(OrderedSet)
self._presuccessors = defaultdict(set) self._presuccessors = defaultdict(OrderedSet)
self._uses = defaultdict(set) self._uses = defaultdict(OrderedSet)
self._defs = defaultdict(set) self._defs = defaultdict(OrderedSet)
self._live_in = defaultdict(set) self._live_in = defaultdict(OrderedSet)
self._live_out = defaultdict(set) self._live_out = defaultdict(OrderedSet)
self._skip_opt = skip_opt self._skip_opt = skip_opt
self.pool = [] self.pool = []
...@@ -116,7 +181,7 @@ class ControlFlowGraph(object): ...@@ -116,7 +181,7 @@ class ControlFlowGraph(object):
# NOTE: must sort the in_diff set for cases that get different cache var. # NOTE: must sort the in_diff set for cases that get different cache var.
# FIXME(typhoonzero): maybe use a "sorted set" is better than this. # FIXME(typhoonzero): maybe use a "sorted set" is better than this.
can_optimize = [ can_optimize = [
x for x in sorted(list(in_diff)) x for x in in_diff
if self._check_var_validity(block_desc, x, is_forward) if self._check_var_validity(block_desc, x, is_forward)
] ]
if can_optimize: if can_optimize:
...@@ -224,7 +289,7 @@ class ControlFlowGraph(object): ...@@ -224,7 +289,7 @@ class ControlFlowGraph(object):
if self.pool: if self.pool:
# NOTE: must sort the in_diff set for cases that get different cache var. # NOTE: must sort the in_diff set for cases that get different cache var.
defs_can_optimize = [ defs_can_optimize = [
x for x in sorted(list(self._defs[i])) x for x in self._defs[i]
if self._check_var_validity(block_desc, x, is_forward) if self._check_var_validity(block_desc, x, is_forward)
] ]
out_pair = [ out_pair = [
...@@ -381,7 +446,19 @@ def _get_cfgs(input_program): ...@@ -381,7 +446,19 @@ def _get_cfgs(input_program):
return cfgs return cfgs
def memory_optimize(input_program, skip_opt_set=None, print_log=False, level=0): def _is_opt_role_op(op):
op_maker = core.op_proto_and_checker_maker
optimize_role = core.op_proto_and_checker_maker.OpRole.Optimize
if op_maker.kOpRoleAttrName() in op.attr_names and \
int(op.all_attrs()[op_maker.kOpRoleAttrName()]) == int(optimize_role):
return True
def memory_optimize(input_program,
skip_opt_set=None,
print_log=False,
level=0,
skip_grads=False):
"""Optimize memory by reusing var memory. """Optimize memory by reusing var memory.
Note: it doesn't not support subblock nested in subblock. Note: it doesn't not support subblock nested in subblock.
...@@ -398,6 +475,19 @@ def memory_optimize(input_program, skip_opt_set=None, print_log=False, level=0): ...@@ -398,6 +475,19 @@ def memory_optimize(input_program, skip_opt_set=None, print_log=False, level=0):
raise ValueError("only support opt_level 0 or 1.") raise ValueError("only support opt_level 0 or 1.")
global PRINT_LOG global PRINT_LOG
PRINT_LOG = print_log PRINT_LOG = print_log
if skip_grads:
grad_set = set()
OP_ROLE_VAR = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
for op in input_program.global_block().ops:
if _is_opt_role_op(op):
if op.attr(OP_ROLE_VAR):
grad_name = op.attr(OP_ROLE_VAR)[1]
grad_set.add(grad_name)
if not skip_opt_set:
skip_opt_set = grad_set
else:
skip_opt_set.update(grad_set)
cfgs = _get_cfgs(input_program) cfgs = _get_cfgs(input_program)
for cfg in cfgs: for cfg in cfgs:
cfg.memory_optimize(skip_opt_set=skip_opt_set, level=level) cfg.memory_optimize(skip_opt_set=skip_opt_set, level=level)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册