diff --git a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py index 3e58e125de4188144646236f7999c620cd8e9459..e6fb8a91a8d72c6a1410adafe69ea4b13962033f 100644 --- a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py +++ b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py @@ -77,6 +77,9 @@ class ControlFlowGraph(object): for i in range(self.op_size): self._uses[i].update(self._ops[i].input_arg_names()) self._defs[i].update(self._ops[i].output_arg_names()) + self._live_in[i] = self._uses[i] + # print(self._successors) + # print(self._presuccessors) def _update_graph(self, old_name, new_name, begin_idx=0): for i in range(begin_idx, self.op_size): @@ -86,12 +89,18 @@ class ControlFlowGraph(object): if old_name in self._defs[i]: self._defs[i].remove(old_name) self._defs[i].add(new_name) + # for i in range(begin_idx, -1, -1): if old_name in self._live_in[i]: self._live_in[i].remove(old_name) - self._live_out[i].add(new_name) + self._live_in[i].add(new_name) + # if old_name == "concat_3.tmp_0@GRAD": + # print("new_name", new_name) + # print("live_in ", i , self._live_in[i]) if old_name in self._live_out[i]: self._live_out[i].remove(old_name) self._live_out[i].add(new_name) + # if old_name == "concat_3.tmp_0@GRAD": + # print("live_out ", i , self._live_out[i]) def _reach_fixed_point(self, live_in, live_out): """Check if the liveness set has stablized.""" @@ -105,22 +114,40 @@ class ControlFlowGraph(object): return False return True + # def _dataflow_analyze(self): + # self._build_graph() + # live_in = defaultdict(set) + # live_out = defaultdict(set) + # # Repeatedly apply liveness updates until the algorithm stablize + # # on a complete set live input vars and live output vars. + # counter = 0 + # print(self._successors) + # while True: + # counter += 1 + # for i in reversed(list(range(self.op_size))): + # live_in[i] = set(self._live_in[i]) + # live_out[i] = set(self._live_out[i]) + # for s in self._successors[i]: + # self._live_out[i] |= self._live_in[s] + # self._live_in[i] = self._uses[i] | ( + # self._live_out[i] - self._defs[i]) + # if self._reach_fixed_point(live_in, live_out): + # break + def _dataflow_analyze(self): self._build_graph() live_in = defaultdict(set) - live_out = defaultdict(set) - # Repeatedly apply liveness updates until the algorithm stablize - # on a complete set live input vars and live output vars. - while True: - for i in reversed(list(range(self.op_size))): - live_in[i] = set(self._live_in[i]) - live_out[i] = set(self._live_out[i]) - for s in self._successors[i]: - self._live_out[i] |= self._live_in[s] - self._live_in[i] = self._uses[i] | ( - self._live_out[i] - self._defs[i]) - if self._reach_fixed_point(live_in, live_out): - break + worklist = list(range(len(self._ops) - 1, -1, -1)) + while worklist: + i = worklist.pop(0) + live_in[i] = set(self._live_in[i]) + for s in self._successors[i]: + self._live_out[i] |= self._live_in[s] + self._live_in[i] = self._uses[i] | ( + self._live_out[i] - self._defs[i]) + if live_in[i] != self._live_in[i]: + for d in self._presuccessors[i]: + worklist.append(d) def _get_diff(self, a, b): u = a & b @@ -218,6 +245,17 @@ class ControlFlowGraph(object): continue block_desc = op.block() is_forward = i < self._forward_num + in_diff, _ = self._get_diff(self._live_in[i], self._live_out[i]) + can_optimize = [ + x for x in in_diff + if self._check_var_validity(block_desc, x, is_forward) + ] + if can_optimize: + for var_name in can_optimize: + self.pool.append((var_name, self._find_var( + block_desc, var_name, is_forward).shape())) + # print(op.type(), i, self.pool) + # print(self._live_in[i]) if self.pool: defs_can_optimize = [ x for x in self._defs[i] @@ -249,21 +287,24 @@ class ControlFlowGraph(object): if x_dtype != cache_dtype: continue + self.pool.pop(index) + if x == cache_var: + break + if PRINT_LOG: print(("Hit Cache !!!! cache pool index " "is %d, var name is %s, " "cached var name is %s, " "var shape is %s ") % (index, x, cache_var, str(cache_shape))) - self.pool.pop(index) - if x == cache_var: - break # Rename the var to the cache var already with # memory allocated in order to reuse the memory. _rename_arg_(self._ops, x, cache_var, begin_idx=i) self._program.block(block_desc.id).var(cpt.to_text( x)).desc = self._find_var(block_desc, cache_var, is_forward) + if x == "concat_3.tmp_0@GRAD": + print("Update Graph", i) self._update_graph(x, cache_var, begin_idx=i) break @@ -272,10 +313,13 @@ class ControlFlowGraph(object): x for x in in_diff if self._check_var_validity(block_desc, x, is_forward) ] + keys = set([key for key,shape in self.pool]) if can_optimize: for var_name in can_optimize: - self.pool.append((var_name, self._find_var( - block_desc, var_name, is_forward).shape())) + if var_name not in keys: + self.pool.append((var_name, self._find_var( + block_desc, var_name, is_forward).shape())) + # print(op.type(), i, self.pool) def _process_sub_block_pair(pdesc, sub_block_pair):