提交 142fac18 编写于 作者: Q qiaolongfei

add print_log to memory_optimize

上级 84aea8a8
...@@ -31,6 +31,8 @@ dtype_to_size = { ...@@ -31,6 +31,8 @@ dtype_to_size = {
sub_block_ops = ["while", "while_grad", "parallel_do", "parallel_do_grad"] sub_block_ops = ["while", "while_grad", "parallel_do", "parallel_do_grad"]
PRINT_LOG = False
class ControlFlowGraph(object): class ControlFlowGraph(object):
def __init__(self, Program, ops, forward_num, skip_opt): def __init__(self, Program, ops, forward_num, skip_opt):
...@@ -170,7 +172,7 @@ class ControlFlowGraph(object): ...@@ -170,7 +172,7 @@ class ControlFlowGraph(object):
block_desc, cache_var, is_forward).dtype() block_desc, cache_var, is_forward).dtype()
# TODO(qijun): actually, we should compare dtype_to_size[x_dtype] # TODO(qijun): actually, we should compare dtype_to_size[x_dtype]
# and dtype_to_size[cache_dtype] # and dtype_to_size[cache_dtype]
if x_dtype == cache_dtype: if x_dtype == cache_dtype and PRINT_LOG:
print(("Hit Cache !!!! cache pool index " print(("Hit Cache !!!! cache pool index "
"is %d, var name is %s, " "is %d, var name is %s, "
"cached var name is %s, " "cached var name is %s, "
...@@ -277,7 +279,9 @@ def _get_cfgs(input_program): ...@@ -277,7 +279,9 @@ def _get_cfgs(input_program):
return cfgs return cfgs
def memory_optimize(input_program): def memory_optimize(input_program, print_log=False):
global PRINT_LOG
PRINT_LOG = print_log
cfgs = _get_cfgs(input_program) cfgs = _get_cfgs(input_program)
for cfg in cfgs: for cfg in cfgs:
cfg.memory_optimize() cfg.memory_optimize()
...@@ -49,7 +49,7 @@ avg_cost = fluid.layers.mean(x=cost) ...@@ -49,7 +49,7 @@ avg_cost = fluid.layers.mean(x=cost)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.01) sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.01)
sgd_optimizer.minimize(avg_cost) sgd_optimizer.minimize(avg_cost)
fluid.memory_optimize(fluid.default_main_program()) fluid.memory_optimize(fluid.default_main_program(), print_log=True)
BATCH_SIZE = 200 BATCH_SIZE = 200
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册