From 697145cd2010e0b0ed87c7202f9a16d69d04f95d Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Mon, 29 Oct 2018 21:05:32 +0800 Subject: [PATCH] Merge pull request #14055 from dzhwinter/fix/mem_opt OrderSet default is mutableSet, not set. test=release/1.1 --- .../fluid/transpiler/memory_optimization_transpiler.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py index 861bb5fae..c9f1be934 100755 --- a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py +++ b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py @@ -171,7 +171,7 @@ class ControlFlowGraph(object): self._live_out[i] |= self._live_in[s] self._live_in[i] = self._uses[i] | ( self._live_out[i] - self._defs[i]) - if live_in[i] != self._live_in[i]: + if live_in[i] != set(self._live_in[i]): for d in self._presuccessors[i]: worklist.append(d) @@ -321,8 +321,7 @@ class ControlFlowGraph(object): if not compare_shape(x_shape, cache_shape, level): continue - # TODO(qijun): actually, we should compare - # dtype_to_size[x_dtype] and dtype_to_size[cache_dtype] + # TODO(qijun): dtype_to_size[x_dtype] and dtype_to_size[cache_dtype] if x_dtype != cache_dtype: continue @@ -487,7 +486,6 @@ def memory_optimize(input_program, skip_opt_set = grad_set else: skip_opt_set.update(grad_set) - cfgs = _get_cfgs(input_program) for cfg in cfgs: cfg.memory_optimize(skip_opt_set=skip_opt_set, level=level) -- GitLab