提交 5212b2a9 编写于 作者: D dzhwinter

"rerun"

上级 2617ac9d
...@@ -47,7 +47,6 @@ PRINT_LOG = False ...@@ -47,7 +47,6 @@ PRINT_LOG = False
class ControlFlowGraph(object): class ControlFlowGraph(object):
def __init__(self, program, ops, forward_num, skip_opt): def __init__(self, program, ops, forward_num, skip_opt):
self._program = program self._program = program
self._dup_program = program.clone()
self._ops = ops self._ops = ops
self._forward_num = forward_num self._forward_num = forward_num
self._successors = defaultdict(set) self._successors = defaultdict(set)
...@@ -230,8 +229,6 @@ class ControlFlowGraph(object): ...@@ -230,8 +229,6 @@ class ControlFlowGraph(object):
for x in defs_can_optimize for x in defs_can_optimize
] ]
for x, x_shape in out_pair: for x, x_shape in out_pair:
if (x, x_shape) in self.pool:
raise ValueError("x in pool, %s, %s" % (x, x_shape))
# If x is both in uses and defs, it can not be optimized! # If x is both in uses and defs, it can not be optimized!
if x in self._uses[i]: if x in self._uses[i]:
continue continue
...@@ -239,14 +236,15 @@ class ControlFlowGraph(object): ...@@ -239,14 +236,15 @@ class ControlFlowGraph(object):
cache_var = cache_pair[0] cache_var = cache_pair[0]
cache_shape = cache_pair[1] cache_shape = cache_pair[1]
if not self._has_var(block_desc, cache_var, is_forward): if not self._has_var(block_desc, cache_var, is_forward):
raise ValueError("cache", if PRINT_LOG:
cpt.to_text(cache_var), print("cache %s not exists!" %
" Not exists!") (cpt.to_text(cache_var)))
continue
if x == cache_var: if x == cache_var:
raise ValueError("x : ", if PRINT_LOG:
cpt.to_text(x), " cache : ", print("x : ", cpt.to_text(x), " cache : ",
cpt.to_text(cache_var), cpt.to_text(cache_var), " is same var!")
" is same var!") break
x_dtype = self._find_var(block_desc, x, x_dtype = self._find_var(block_desc, x,
is_forward).dtype() is_forward).dtype()
...@@ -383,10 +381,13 @@ def memory_optimize(input_program, skip_opt_set=None, print_log=False, level=0): ...@@ -383,10 +381,13 @@ def memory_optimize(input_program, skip_opt_set=None, print_log=False, level=0):
Note: it doesn't not support subblock nested in subblock. Note: it doesn't not support subblock nested in subblock.
:param input_program(str): Input Program Args:
:param print_log: whether to print debug log. input_program(str): Input Program
:param level: If level=0, reuse if the shape is completely equal, o skip_opt_set(set): vars wil be skipped in memory optimze
:return: print_log(bool): whether to print debug log.
level(int): If level=0, reuse if the shape is completely equal, o
Returns:
None
""" """
if level != 0 and level != 1: if level != 0 and level != 1:
raise ValueError("only support opt_level 0 or 1.") raise ValueError("only support opt_level 0 or 1.")
...@@ -407,6 +408,9 @@ def release_memory(input_program, skip_opt_set=None): ...@@ -407,6 +408,9 @@ def release_memory(input_program, skip_opt_set=None):
Args: Args:
input_program(Program): The program will be inserted :code:`delete_op`. input_program(Program): The program will be inserted :code:`delete_op`.
skip_opt_set(set): vars wil be skipped in memory optimze
Returns:
None
""" """
cfgs = _get_cfgs(input_program) cfgs = _get_cfgs(input_program)
for cfg in cfgs: for cfg in cfgs:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册