Unverified commit c89664e8, authored by Qiao Longfei, committed by GitHub

Merge pull request #8831 from jacquesqiao/add-printlog-to-memoryoptimize

add print_log to memory_optimize
@@ -31,6 +31,8 @@ dtype_to_size = {
 sub_block_ops = ["while", "while_grad", "parallel_do", "parallel_do_grad"]
 
+PRINT_LOG = False
+
 
 class ControlFlowGraph(object):
     def __init__(self, Program, ops, forward_num, skip_opt):
@@ -171,12 +173,14 @@ class ControlFlowGraph(object):
                     # TODO(qijun): actually, we should compare dtype_to_size[x_dtype]
                     # and dtype_to_size[cache_dtype]
                     if x_dtype == cache_dtype:
-                        print(("Hit Cache !!!! cache pool index "
-                               "is %d, var name is %s, "
-                               "cached var name is %s, "
-                               "var shape is %s ") %
-                              (index, x, cache_var,
-                               str(cache_shape)))
+                        if PRINT_LOG:
+                            print(
+                                ("Hit Cache !!!! cache pool index "
+                                 "is %d, var name is %s, "
+                                 "cached var name is %s, "
+                                 "var shape is %s ") %
+                                (index, x, cache_var,
+                                 str(cache_shape)))
                         self.pool.pop(index)
                         if x == cache_var:
                             break
@@ -277,7 +281,9 @@ def _get_cfgs(input_program):
     return cfgs
 
 
-def memory_optimize(input_program):
+def memory_optimize(input_program, print_log=False):
+    global PRINT_LOG
+    PRINT_LOG = print_log
     cfgs = _get_cfgs(input_program)
     for cfg in cfgs:
         cfg.memory_optimize()
@@ -49,7 +49,7 @@ avg_cost = fluid.layers.mean(x=cost)
 sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.01)
 sgd_optimizer.minimize(avg_cost)
 
-fluid.memory_optimize(fluid.default_main_program())
+fluid.memory_optimize(fluid.default_main_program(), print_log=True)
 
 BATCH_SIZE = 200
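Taken together, the change lets callers opt in to the cache-hit log without editing the transpiler source: memory_optimize gains a print_log argument that flips the module-level PRINT_LOG switch. A minimal usage sketch based on the diff above follows; the network definition is an illustrative stand-in, not part of this commit, and only the memory_optimize call with print_log comes from the PR.

import paddle.fluid as fluid

# Illustrative stand-in network; any program built with fluid is handled the same way.
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1)
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(x=cost)

sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.01)
sgd_optimizer.minimize(avg_cost)

# New in this PR: pass print_log=True to see a "Hit Cache !!!!" line each time
# a variable reuses a slot from the memory pool; leave it at the default False
# to keep the pass silent.
fluid.memory_optimize(fluid.default_main_program(), print_log=True)

Note that before this change the cache-hit message was printed unconditionally; with the new flag the pass is silent by default and only prints when print_log=True is passed.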