提交 142fac18 编写于 作者: Q qiaolongfei

add print_log to memory_optimize

上级 84aea8a8
......@@ -31,6 +31,8 @@ dtype_to_size = {
sub_block_ops = ["while", "while_grad", "parallel_do", "parallel_do_grad"]
PRINT_LOG = False
class ControlFlowGraph(object):
def __init__(self, Program, ops, forward_num, skip_opt):
......@@ -170,7 +172,7 @@ class ControlFlowGraph(object):
block_desc, cache_var, is_forward).dtype()
# TODO(qijun): actually, we should compare dtype_to_size[x_dtype]
# and dtype_to_size[cache_dtype]
if x_dtype == cache_dtype:
if x_dtype == cache_dtype and PRINT_LOG:
print(("Hit Cache !!!! cache pool index "
"is %d, var name is %s, "
"cached var name is %s, "
......@@ -277,7 +279,9 @@ def _get_cfgs(input_program):
return cfgs
def memory_optimize(input_program):
def memory_optimize(input_program, print_log=False):
global PRINT_LOG
PRINT_LOG = print_log
cfgs = _get_cfgs(input_program)
for cfg in cfgs:
cfg.memory_optimize()
......@@ -49,7 +49,7 @@ avg_cost = fluid.layers.mean(x=cost)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.01)
sgd_optimizer.minimize(avg_cost)
fluid.memory_optimize(fluid.default_main_program())
fluid.memory_optimize(fluid.default_main_program(), print_log=True)
BATCH_SIZE = 200
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册