From 5212b2a9699621271dc40f4d563c05ec0abd76bf Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Mon, 17 Sep 2018 17:12:22 +0800 Subject: [PATCH] "rerun" --- .../memory_optimization_transpiler.py | 32 +++++++++++-------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py index ac57a8b4e..76adedfad 100755 --- a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py +++ b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py @@ -47,7 +47,6 @@ PRINT_LOG = False class ControlFlowGraph(object): def __init__(self, program, ops, forward_num, skip_opt): self._program = program - self._dup_program = program.clone() self._ops = ops self._forward_num = forward_num self._successors = defaultdict(set) @@ -230,8 +229,6 @@ class ControlFlowGraph(object): for x in defs_can_optimize ] for x, x_shape in out_pair: - if (x, x_shape) in self.pool: - raise ValueError("x in pool, %s, %s" % (x, x_shape)) # If x is both in uses and defs, it can not be optimized! if x in self._uses[i]: continue @@ -239,14 +236,15 @@ class ControlFlowGraph(object): cache_var = cache_pair[0] cache_shape = cache_pair[1] if not self._has_var(block_desc, cache_var, is_forward): - raise ValueError("cache", - cpt.to_text(cache_var), - " Not exists!") + if PRINT_LOG: + print("cache %s not exists!" % + (cpt.to_text(cache_var))) + continue if x == cache_var: - raise ValueError("x : ", - cpt.to_text(x), " cache : ", - cpt.to_text(cache_var), - " is same var!") + if PRINT_LOG: + print("x : ", cpt.to_text(x), " cache : ", + cpt.to_text(cache_var), " is same var!") + break x_dtype = self._find_var(block_desc, x, is_forward).dtype() @@ -383,10 +381,13 @@ def memory_optimize(input_program, skip_opt_set=None, print_log=False, level=0): Note: it doesn't not support subblock nested in subblock. - :param input_program(str): Input Program - :param print_log: whether to print debug log. - :param level: If level=0, reuse if the shape is completely equal, o - :return: + Args: + input_program(str): Input Program + skip_opt_set(set): vars wil be skipped in memory optimze + print_log(bool): whether to print debug log. + level(int): If level=0, reuse if the shape is completely equal, o + Returns: + None """ if level != 0 and level != 1: raise ValueError("only support opt_level 0 or 1.") @@ -407,6 +408,9 @@ def release_memory(input_program, skip_opt_set=None): Args: input_program(Program): The program will be inserted :code:`delete_op`. + skip_opt_set(set): vars wil be skipped in memory optimze + Returns: + None """ cfgs = _get_cfgs(input_program) for cfg in cfgs: -- GitLab