diff --git a/python/paddle/fluid/memory_optimization_transpiler.py b/python/paddle/fluid/memory_optimization_transpiler.py index 708ca08b17c85efbb25ecaf2580b7141421e25b9..4fa2d03ef563b98b2eec576bf87d4b2e54ca0a36 100644 --- a/python/paddle/fluid/memory_optimization_transpiler.py +++ b/python/paddle/fluid/memory_optimization_transpiler.py @@ -31,6 +31,8 @@ dtype_to_size = { sub_block_ops = ["while", "while_grad", "parallel_do", "parallel_do_grad"] +PRINT_LOG = False + class ControlFlowGraph(object): def __init__(self, Program, ops, forward_num, skip_opt): @@ -171,12 +173,14 @@ class ControlFlowGraph(object): # TODO(qijun): actually, we should compare dtype_to_size[x_dtype] # and dtype_to_size[cache_dtype] if x_dtype == cache_dtype: - print(("Hit Cache !!!! cache pool index " - "is %d, var name is %s, " - "cached var name is %s, " - "var shape is %s ") % - (index, x, cache_var, - str(cache_shape))) + if PRINT_LOG: + print( + ("Hit Cache !!!! cache pool index " + "is %d, var name is %s, " + "cached var name is %s, " + "var shape is %s ") % + (index, x, cache_var, + str(cache_shape))) self.pool.pop(index) if x == cache_var: break @@ -277,7 +281,9 @@ def _get_cfgs(input_program): return cfgs -def memory_optimize(input_program): +def memory_optimize(input_program, print_log=False): + global PRINT_LOG + PRINT_LOG = print_log cfgs = _get_cfgs(input_program) for cfg in cfgs: cfg.memory_optimize() diff --git a/python/paddle/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py b/python/paddle/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py index 7648bb9fe1c8d2dae3590fa67f781743ba297c32..c9d2a5ecaab0669f308b5b9c5cf74d0212fa462a 100644 --- a/python/paddle/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py +++ b/python/paddle/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py @@ -49,7 +49,7 @@ avg_cost = fluid.layers.mean(x=cost) sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.01) sgd_optimizer.minimize(avg_cost) -fluid.memory_optimize(fluid.default_main_program()) +fluid.memory_optimize(fluid.default_main_program(), print_log=True) BATCH_SIZE = 200