From 142fac18ec41abe570147c44e6b434f807efae88 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Wed, 7 Mar 2018 16:27:38 +0800 Subject: [PATCH] add print_log to memory_optimize --- python/paddle/fluid/memory_optimization_transpiler.py | 8 ++++++-- .../book_memory_optimization/test_memopt_fit_a_line.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/memory_optimization_transpiler.py b/python/paddle/fluid/memory_optimization_transpiler.py index 708ca08b17..e82456a99f 100644 --- a/python/paddle/fluid/memory_optimization_transpiler.py +++ b/python/paddle/fluid/memory_optimization_transpiler.py @@ -31,6 +31,8 @@ dtype_to_size = { sub_block_ops = ["while", "while_grad", "parallel_do", "parallel_do_grad"] +PRINT_LOG = False + class ControlFlowGraph(object): def __init__(self, Program, ops, forward_num, skip_opt): @@ -170,7 +172,7 @@ class ControlFlowGraph(object): block_desc, cache_var, is_forward).dtype() # TODO(qijun): actually, we should compare dtype_to_size[x_dtype] # and dtype_to_size[cache_dtype] - if x_dtype == cache_dtype: + if x_dtype == cache_dtype and PRINT_LOG: print(("Hit Cache !!!! cache pool index " "is %d, var name is %s, " "cached var name is %s, " @@ -277,7 +279,9 @@ def _get_cfgs(input_program): return cfgs -def memory_optimize(input_program): +def memory_optimize(input_program, print_log=False): + global PRINT_LOG + PRINT_LOG = print_log cfgs = _get_cfgs(input_program) for cfg in cfgs: cfg.memory_optimize() diff --git a/python/paddle/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py b/python/paddle/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py index 7648bb9fe1..c9d2a5ecaa 100644 --- a/python/paddle/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py +++ b/python/paddle/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py @@ -49,7 +49,7 @@ avg_cost = fluid.layers.mean(x=cost) sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.01) sgd_optimizer.minimize(avg_cost) -fluid.memory_optimize(fluid.default_main_program()) +fluid.memory_optimize(fluid.default_main_program(), print_log=True) BATCH_SIZE = 200 -- GitLab