From ef60a6544e9d93a9edc34b7aed33d477825c0a1e Mon Sep 17 00:00:00 2001
From: dzhwinter
Date: Fri, 14 Sep 2018 17:12:19 +0800
Subject: [PATCH] "add test"

---
 .../test_memory_optimization_transpiler.py |  25 ++++
 .../memory_optimization_transpiler.py      | 115 ++++++------------
 2 files changed, 65 insertions(+), 75 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_memory_optimization_transpiler.py b/python/paddle/fluid/tests/unittests/test_memory_optimization_transpiler.py
index 67733807f8..dc5bdd2bf5 100644
--- a/python/paddle/fluid/tests/unittests/test_memory_optimization_transpiler.py
+++ b/python/paddle/fluid/tests/unittests/test_memory_optimization_transpiler.py
@@ -66,6 +66,31 @@ class TestMemoryTranspiler2(unittest.TestCase):
         print("after optimization")
         print(str(result_program))
 
 
+class TestMemoryTranspiler3(unittest.TestCase):
+    def setUp(self):
+        program = Program()
+        with program_guard(program, startup_program=Program()):
+            word = fluid.layers.data(name='word', shape=[1], dtype='int64')
+            emb = [fluid.layers.embedding(word, size=[65536, 256], param_attr='emb')
+                   for _ in range(6)]
+
+            left = emb.pop(0)
+            while len(emb) != 0:
+                right = emb.pop(0)
+                left = fluid.layers.concat([left, right])
+            emb = fluid.layers.mean(left)
+            fluid.backward.append_backward(emb)
+        self.program = program
+
+    def test_cascade_reuse(self):
+        block = self.program.block(0)
+        # variable reuse in programdesc
+        self.assertTrue("concat_4.tmp_0@GRAD" in block.vars)
+        self.assertTrue("concat_3.tmp_0@GRAD" not in block.vars)
+        self.assertTrue("concat_2.tmp_0@GRAD" not in block.vars)
+        self.assertTrue("concat_1.tmp_0@GRAD" not in block.vars)
+        self.assertTrue("concat_0.tmp_0@GRAD" not in block.vars)
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py
index e6fb8a91a8..b512534882 100644
--- a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py
+++ b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py
@@ -47,6 +47,7 @@ PRINT_LOG = False
 class ControlFlowGraph(object):
     def __init__(self, program, ops, forward_num, skip_opt):
         self._program = program
+        self._dup_program = program.clone()
         self._ops = ops
         self._forward_num = forward_num
         self._successors = defaultdict(set)
@@ -56,6 +57,7 @@ class ControlFlowGraph(object):
         self._live_in = defaultdict(set)
         self._live_out = defaultdict(set)
         self._skip_opt = skip_opt
+        self.pool = []
 
     def _add_connections(self, connections):
         """Populates _successors and _presuccessors for two neighbor nodes."""
@@ -78,8 +80,6 @@ class ControlFlowGraph(object):
             self._uses[i].update(self._ops[i].input_arg_names())
             self._defs[i].update(self._ops[i].output_arg_names())
             self._live_in[i] = self._uses[i]
-            # print(self._successors)
-            # print(self._presuccessors)
 
     def _update_graph(self, old_name, new_name, begin_idx=0):
         for i in range(begin_idx, self.op_size):
@@ -89,50 +89,13 @@ class ControlFlowGraph(object):
             if old_name in self._defs[i]:
                 self._defs[i].remove(old_name)
                 self._defs[i].add(new_name)
-            # for i in range(begin_idx, -1, -1):
             if old_name in self._live_in[i]:
                 self._live_in[i].remove(old_name)
                 self._live_in[i].add(new_name)
-                # if old_name == "concat_3.tmp_0@GRAD":
-                #     print("new_name", new_name)
-                #     print("live_in ", i , self._live_in[i])
             if old_name in self._live_out[i]:
                 self._live_out[i].remove(old_name)
                 self._live_out[i].add(new_name)
-                # if old_name == "concat_3.tmp_0@GRAD":
-                #     print("live_out ", i , self._live_out[i])
-
-    def _reach_fixed_point(self, live_in, live_out):
-        """Check if the liveness set has stablized."""
-        if len(live_in) != len(self._live_in):
-            return False
-        if len(live_out) != len(self._live_out):
-            return False
-        for i in range(self.op_size):
-            if (live_in[i] != self._live_in[i] or
-                    live_out[i] != self._live_out[i]):
-                return False
-        return True
-    # def _dataflow_analyze(self):
-    #     self._build_graph()
-    #     live_in = defaultdict(set)
-    #     live_out = defaultdict(set)
-    #     # Repeatedly apply liveness updates until the algorithm stablize
-    #     # on a complete set live input vars and live output vars.
-    #     counter = 0
-    #     print(self._successors)
-    #     while True:
-    #         counter += 1
-    #         for i in reversed(list(range(self.op_size))):
-    #             live_in[i] = set(self._live_in[i])
-    #             live_out[i] = set(self._live_out[i])
-    #             for s in self._successors[i]:
-    #                 self._live_out[i] |= self._live_in[s]
-    #             self._live_in[i] = self._uses[i] | (
-    #                 self._live_out[i] - self._defs[i])
-    #         if self._reach_fixed_point(live_in, live_out):
-    #             break
 
     def _dataflow_analyze(self):
         self._build_graph()
 
@@ -149,6 +112,20 @@ class ControlFlowGraph(object):
                     for d in self._presuccessors[i]:
                         worklist.append(d)
 
+    def _fill_pool(self, i, is_forward):
+        block_desc = self._ops[i].block()
+        in_diff, _ = self._get_diff(self._live_in[i], self._live_out[i])
+        can_optimize = [
+            x for x in in_diff
+            if self._check_var_validity(block_desc, x, is_forward)
+        ]
+        if can_optimize:
+            for var_name in can_optimize:
+                cache = (var_name, self._find_var(
+                    block_desc, var_name, is_forward).shape())
+                if cache not in self.pool:
+                    self.pool.append(cache)
+
     def _get_diff(self, a, b):
         u = a & b
         return a - u, b - u
@@ -238,24 +215,15 @@ class ControlFlowGraph(object):
         # update skip set to meet users' demand
         if skip_opt_set:
             self._skip_opt.update(skip_opt_set)
-        self.pool = []
+        # self.pool = []
         for i in range(self.op_size):
             op = self._ops[i]
             if op.type() in SUB_BLOCK_OPS:
                 continue
             block_desc = op.block()
             is_forward = i < self._forward_num
-            in_diff, _ = self._get_diff(self._live_in[i], self._live_out[i])
-            can_optimize = [
-                x for x in in_diff
-                if self._check_var_validity(block_desc, x, is_forward)
-            ]
-            if can_optimize:
-                for var_name in can_optimize:
-                    self.pool.append((var_name, self._find_var(
-                        block_desc, var_name, is_forward).shape()))
+            self._fill_pool(i, is_forward)
             # print(op.type(), i, self.pool)
-            # print(self._live_in[i])
             if self.pool:
                 defs_can_optimize = [
@@ -266,60 +234,57 @@ class ControlFlowGraph(object):
                     for x in defs_can_optimize
                 ]
                 for x, x_shape in out_pair:
+                    if (x, x_shape) in self.pool:
+                        raise ValueError("x in pool")
                     # If x is both in uses and defs, it can not be optimized!
                     if x in self._uses[i]:
+                        # print(self.pool, op.type(), cpt.to_text(x))
+                        # raise ValueError("x in use!", cpt.to_text(x))
                         continue
                     for index, cache_pair in enumerate(self.pool):
                         cache_var = cache_pair[0]
                         cache_shape = cache_pair[1]
-                        if not compare_shape(x_shape, cache_shape, level):
-                            continue
-
                         if not self._has_var(block_desc, cache_var, is_forward):
-                            continue
+                            raise ValueError("cache", cpt.to_text(cache_var), " Not exists!")
+                        if x == cache_var:
+                            raise ValueError("x : ", cpt.to_text(x), " cache : ", cpt.to_text(cache_var), " is same var!")
 
                         x_dtype = self._find_var(block_desc, x,
                                                  is_forward).dtype()
                         cache_dtype = self._find_var(block_desc, cache_var,
                                                      is_forward).dtype()
+
+                        if not compare_shape(x_shape, cache_shape, level):
+                            continue
                         # TODO(qijun): actually, we should compare
                         # dtype_to_size[x_dtype] and dtype_to_size[cache_dtype]
                         if x_dtype != cache_dtype:
                             continue
 
-                        self.pool.pop(index)
-                        if x == cache_var:
-                            break
-
                         if PRINT_LOG:
                             print(("Hit Cache !!!! cache pool index "
                                    "is %d, var name is %s, "
                                    "cached var name is %s, "
                                    "var shape is %s ") % (index, x, cache_var,
                                                           str(cache_shape)))
+                        self.pool.pop(index)
                         # Rename the var to the cache var already with
                         # memory allocated in order to reuse the memory.
                         _rename_arg_(self._ops, x, cache_var, begin_idx=i)
-                        self._program.block(block_desc.id).var(cpt.to_text(
-                            x)).desc = self._find_var(block_desc, cache_var,
-                                                      is_forward)
-                        if x == "concat_3.tmp_0@GRAD":
-                            print("Update Graph", i)
+                        self._program.block(block_desc.id)._remove_var(cpt.to_text(
+                            x))
+                        # if str(self._program) != str(self._dup_program):
+                        #     with open("./program_middle", "w") as f:
+                        #         f.write(str(self._program))
+                        #         f.flush()
+                        #     exit(0)
+                        # self._program.block(block_desc.id).var(cpt.to_text(
+                        #     x)).desc = self._find_var(block_desc, cache_var,
+                        #                               is_forward)
                         self._update_graph(x, cache_var, begin_idx=i)
                         break
+            # self._fill_pool(i, is_forward)
 
-            in_diff, _ = self._get_diff(self._live_in[i], self._live_out[i])
-            can_optimize = [
-                x for x in in_diff
-                if self._check_var_validity(block_desc, x, is_forward)
-            ]
-            keys = set([key for key,shape in self.pool])
-            if can_optimize:
-                for var_name in can_optimize:
-                    if var_name not in keys:
-                        self.pool.append((var_name, self._find_var(
-                            block_desc, var_name, is_forward).shape()))
-            # print(op.type(), i, self.pool)
 
 
 def _process_sub_block_pair(pdesc, sub_block_pair):
--
GitLab
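Note for reviewers: the `_dataflow_analyze` retained by this patch is a standard backward liveness analysis driven by a worklist, and both `_fill_pool` and the cascade-reuse test rely on the `live_in`/`live_out` sets it produces. Below is a minimal, self-contained sketch of that algorithm; the `Op` tuple, the `liveness` function, and the three-op straight-line program are illustrative stand-ins, not Paddle APIs.

from collections import defaultdict, namedtuple

# Simplified stand-in for an operator: the names it reads and writes.
Op = namedtuple("Op", ["uses", "defs"])

def liveness(ops, successors):
    """Backward worklist liveness:
    live_in[i] = uses[i] | (live_out[i] - defs[i])."""
    live_in = defaultdict(set)
    live_out = defaultdict(set)
    # Predecessor map, mirroring the transpiler's _presuccessors.
    preds = defaultdict(set)
    for i, succs in successors.items():
        for s in succs:
            preds[s].add(i)
    worklist = list(range(len(ops) - 1, -1, -1))
    while worklist:
        i = worklist.pop(0)
        out = set()
        for s in successors[i]:
            out |= live_in[s]
        live_out[i] = out
        new_in = ops[i].uses | (out - ops[i].defs)
        if new_in != live_in[i]:
            live_in[i] = new_in
            # A changed live_in can change every predecessor's live_out,
            # so predecessors are re-examined (the patch's worklist.append(d)).
            worklist.extend(preds[i])
    return live_in, live_out

# Straight-line program: b = f(a); c = g(b); d = h(c).
ops = [Op({"a"}, {"b"}), Op({"b"}, {"c"}), Op({"c"}, {"d"})]
succ = {0: {1}, 1: {2}, 2: set()}
live_in, live_out = liveness(ops, succ)
assert live_in[1] == {"b"}     # b is live entering op 1...
assert "b" not in live_out[1]  # ...and dead after it, so its buffer is reusable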
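The reuse decision itself (a pool of dead variables matched on shape and dtype, then rename-and-remove) can likewise be sketched independently of ProgramDesc. Everything here is hypothetical illustration: `Var`, `try_reuse`, and the exact-shape test standing in for the patch's `compare_shape`/`level` policy.

from dataclasses import dataclass

@dataclass(frozen=True)
class Var:
    name: str
    shape: tuple
    dtype: str

def try_reuse(pool, new_var):
    """Return (cached_name_or_None, updated_pool).

    pool holds variables whose live range has ended, in FIFO order; a newly
    defined variable takes over the first entry with identical shape/dtype.
    """
    for index, cached in enumerate(pool):
        if cached.name == new_var.name:
            continue  # the patch treats self-caching as an error
        if cached.shape == new_var.shape and cached.dtype == new_var.dtype:
            # Pop the entry so two live variables never share one buffer;
            # the caller then renames new_var -> cached.name everywhere
            # (_rename_arg_) and drops new_var from the block (_remove_var).
            return cached.name, pool[:index] + pool[index + 1:]
    return None, pool

# In the new test, every concat output grad has the same symbolic shape,
# so each freshly defined grad takes over the previous one's buffer:
pool = [Var("concat_4.tmp_0@GRAD", (-1, 256), "float32")]
name, pool = try_reuse(pool, Var("concat_3.tmp_0@GRAD", (-1, 256), "float32"))
assert name == "concat_4.tmp_0@GRAD" and pool == []
# The cascade repeats down to concat_0, which is why test_cascade_reuse
# expects only concat_4.tmp_0@GRAD to remain in the block.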