未验证 提交 c89664e8 编写于 作者: Q Qiao Longfei 提交者: GitHub

Merge pull request #8831 from jacquesqiao/add-printlog-to-memoryoptimize

add print_log to memory_optimize
......@@ -31,6 +31,8 @@ dtype_to_size = {
sub_block_ops = ["while", "while_grad", "parallel_do", "parallel_do_grad"]
PRINT_LOG = False
class ControlFlowGraph(object):
def __init__(self, Program, ops, forward_num, skip_opt):
......@@ -171,7 +173,9 @@ class ControlFlowGraph(object):
# TODO(qijun): actually, we should compare dtype_to_size[x_dtype]
# and dtype_to_size[cache_dtype]
if x_dtype == cache_dtype:
print(("Hit Cache !!!! cache pool index "
if PRINT_LOG:
print(
("Hit Cache !!!! cache pool index "
"is %d, var name is %s, "
"cached var name is %s, "
"var shape is %s ") %
......@@ -277,7 +281,9 @@ def _get_cfgs(input_program):
return cfgs
def memory_optimize(input_program):
def memory_optimize(input_program, print_log=False):
global PRINT_LOG
PRINT_LOG = print_log
cfgs = _get_cfgs(input_program)
for cfg in cfgs:
cfg.memory_optimize()
......@@ -49,7 +49,7 @@ avg_cost = fluid.layers.mean(x=cost)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.01)
sgd_optimizer.minimize(avg_cost)
fluid.memory_optimize(fluid.default_main_program())
fluid.memory_optimize(fluid.default_main_program(), print_log=True)
BATCH_SIZE = 200
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册