From 2617ac9d1a6fa45068748aff2bafd52b4e95efd6 Mon Sep 17 00:00:00 2001
From: dzhwinter
Date: Mon, 17 Sep 2018 08:19:57 +0800
Subject: [PATCH] "add doc string"

---
 .../test_memory_optimization_transpiler.py       |  8 +++++--
 .../memory_optimization_transpiler.py            | 24 +++++++++++--------
 2 files changed, 20 insertions(+), 12 deletions(-)
 mode change 100644 => 100755 python/paddle/fluid/transpiler/memory_optimization_transpiler.py

diff --git a/python/paddle/fluid/tests/unittests/test_memory_optimization_transpiler.py b/python/paddle/fluid/tests/unittests/test_memory_optimization_transpiler.py
index c288333ddb3..275e5c49d5c 100644
--- a/python/paddle/fluid/tests/unittests/test_memory_optimization_transpiler.py
+++ b/python/paddle/fluid/tests/unittests/test_memory_optimization_transpiler.py
@@ -15,6 +15,7 @@
 from __future__ import print_function
 
 import unittest
+import paddle.fluid as fluid
 import paddle.fluid.layers as layers
 import paddle.fluid.optimizer as optimizer
 from paddle.fluid.framework import Program, program_guard
@@ -66,13 +67,16 @@ class TestMemoryTranspiler2(unittest.TestCase):
         print("after optimization")
         print(str(result_program))
 
+
 class TestMemoryTranspiler3(unittest.TestCase):
     def setUp(self):
         program = Program()
         with program_guard(program, startup_program=Program()):
             word = fluid.layers.data(name='word', shape=[1], dtype='int64')
-            emb = [fluid.layers.embedding(word, size=[65536, 256], param_attr='emb')
-                   for _ in range(6)]
+            emb = [
+                fluid.layers.embedding(
+                    word, size=[65536, 256], param_attr='emb') for _ in range(6)
+            ]
 
             left = emb.pop(0)
             while len(emb) != 0:
diff --git a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py
old mode 100644
new mode 100755
index ba792d461c4..ac57a8b4ef9
--- a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py
+++ b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py
@@ -96,7 +96,6 @@ class ControlFlowGraph(object):
                     self._live_out[i].remove(old_name)
                     self._live_out[i].add(new_name)
 
-
     def _dataflow_analyze(self):
         self._build_graph()
         live_in = defaultdict(set)
@@ -121,8 +120,8 @@
         ]
         if can_optimize:
             for var_name in can_optimize:
-                cache = (var_name, self._find_var(
-                    block_desc, var_name, is_forward).shape())
+                cache = (var_name, self._find_var(block_desc, var_name,
+                                                  is_forward).shape())
                 if cache not in self.pool:
                     self.pool.append(cache)
 
@@ -232,7 +231,7 @@
             ]
             for x, x_shape in out_pair:
                 if (x, x_shape) in self.pool:
-                    raise ValueError("x in pool")
+                    raise ValueError("x in pool, %s, %s" % (x, x_shape))
                 # If x is both in uses and defs, it can not be optimized!
                 if x in self._uses[i]:
                     continue
@@ -240,9 +239,14 @@
                     cache_var = cache_pair[0]
                     cache_shape = cache_pair[1]
                     if not self._has_var(block_desc, cache_var, is_forward):
-                        raise ValueError("cache", cpt.to_text(cache_var), " Not exists!")
+                        raise ValueError("cache",
+                                         cpt.to_text(cache_var),
+                                         " Not exists!")
                     if x == cache_var:
-                        raise ValueError("x : ", cpt.to_text(x), " cache : ", cpt.to_text(cache_var), " is same var!")
+                        raise ValueError("x : ",
+                                         cpt.to_text(x), " cache : ",
+                                         cpt.to_text(cache_var),
+                                         " is same var!")
 
                     x_dtype = self._find_var(block_desc, x, is_forward).dtype()
@@ -266,14 +270,14 @@
                     # Rename the var to the cache var already with
                     # memory allocated in order to reuse the memory.
                     _rename_arg_(self._ops, x, cache_var, begin_idx=i)
-                    self._program.block(block_desc.id)._remove_var(cpt.to_text(
-                        x))
+                    self._program.block(block_desc.id).var(cpt.to_text(
+                        x)).desc = self._find_var(block_desc, cache_var,
+                                                  is_forward)
                     self._update_graph(x, cache_var, begin_idx=i)
                     break
             self._fill_pool(i, is_forward)
 
 
-
 def _process_sub_block_pair(pdesc, sub_block_pair):
     """Creates a list of tuple each of which tracks info of a subblock.
 
@@ -379,7 +383,7 @@ def memory_optimize(input_program, skip_opt_set=None, print_log=False, level=0):
 
     Note: it doesn't not support subblock nested in subblock.
 
-    :param input_program: Input Program
+    :param input_program(str): Input Program
     :param print_log: whether to print debug log.
     :param level: If level=0, reuse if the shape is completely equal, o
     :return:
-- 
GitLab
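
For context, a minimal sketch of how the transpiler patched here is typically driven, following the setUp/optimize pattern of the unit tests above. The toy network below is illustrative only (its layer names and sizes are assumptions, not part of this patch) and targets the fluid 1.x API:

    import paddle.fluid as fluid
    import paddle.fluid.layers as layers
    import paddle.fluid.optimizer as optimizer
    from paddle.fluid.framework import Program, program_guard

    # Build a small network inside a fresh Program, as the tests do.
    program = Program()
    with program_guard(program, startup_program=Program()):
        x = layers.data(name='x', shape=[13], dtype='float32')
        y = layers.data(name='y', shape=[1], dtype='float32')
        y_predict = layers.fc(input=x, size=1, act=None)
        cost = layers.square_error_cost(input=y_predict, label=y)
        avg_cost = layers.mean(cost)
        optimizer.SGD(learning_rate=0.001).minimize(avg_cost)

    print("before optimization")
    print(str(program))
    # Rewrites the program in place: variables that are no longer live
    # hand their buffers to later variables; with level=0 a pooled
    # variable is reused only when its shape matches exactly.
    fluid.memory_optimize(program, print_log=False, level=0)
    print("after optimization")
    print(str(program))

Because the call mutates the program in place, diffing the two printed programs shows which output variables were renamed onto reused buffers by the _rename_arg_ / _update_graph path changed in this patch.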