未验证 提交 863f80e3 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #14131 from panyx0718/cherry-pick-1.1

Merge pull request #14055 from dzhwinter/fix/mem_opt
...@@ -171,7 +171,7 @@ class ControlFlowGraph(object): ...@@ -171,7 +171,7 @@ class ControlFlowGraph(object):
self._live_out[i] |= self._live_in[s] self._live_out[i] |= self._live_in[s]
self._live_in[i] = self._uses[i] | ( self._live_in[i] = self._uses[i] | (
self._live_out[i] - self._defs[i]) self._live_out[i] - self._defs[i])
if live_in[i] != self._live_in[i]: if live_in[i] != set(self._live_in[i]):
for d in self._presuccessors[i]: for d in self._presuccessors[i]:
worklist.append(d) worklist.append(d)
...@@ -321,8 +321,7 @@ class ControlFlowGraph(object): ...@@ -321,8 +321,7 @@ class ControlFlowGraph(object):
if not compare_shape(x_shape, cache_shape, level): if not compare_shape(x_shape, cache_shape, level):
continue continue
# TODO(qijun): actually, we should compare # TODO(qijun): dtype_to_size[x_dtype] and dtype_to_size[cache_dtype]
# dtype_to_size[x_dtype] and dtype_to_size[cache_dtype]
if x_dtype != cache_dtype: if x_dtype != cache_dtype:
continue continue
...@@ -487,7 +486,6 @@ def memory_optimize(input_program, ...@@ -487,7 +486,6 @@ def memory_optimize(input_program,
skip_opt_set = grad_set skip_opt_set = grad_set
else: else:
skip_opt_set.update(grad_set) skip_opt_set.update(grad_set)
cfgs = _get_cfgs(input_program) cfgs = _get_cfgs(input_program)
for cfg in cfgs: for cfg in cfgs:
cfg.memory_optimize(skip_opt_set=skip_opt_set, level=level) cfg.memory_optimize(skip_opt_set=skip_opt_set, level=level)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册