提交 00776b16 编写于 作者: D dzhwinter 提交者: chengduo

fix memory opt skip set by name (#14774)

* random failed. rerun ci. test=develop

* windows failed. rerun ci. test=develop
上级 c4c5f0b8
......@@ -22,6 +22,15 @@ from paddle.fluid.framework import Program, program_guard
from paddle.fluid.transpiler import memory_optimize
def _get_vars(prog):
assert (isinstance(prog, Program))
all_vars = set()
for op in prog.global_block().ops:
all_vars.update(op.input_arg_names)
all_vars.update(op.output_arg_names)
return all_vars
class TestControlFlowGraph(unittest.TestCase):
def setUp(self):
program = Program()
......@@ -37,11 +46,11 @@ class TestControlFlowGraph(unittest.TestCase):
self.program = program
def test_control_flow_graph(self):
print("before optimization")
print(str(self.program))
result_program = memory_optimize(self.program)
print("after optimization")
print(str(result_program))
result_program = self.program.clone()
memory_optimize(self.program)
old_vars = _get_vars(self.program)
new_vars = _get_vars(result_program)
self.assertTrue(old_vars != new_vars)
class TestMemoryTranspiler2(unittest.TestCase):
......@@ -58,14 +67,22 @@ class TestMemoryTranspiler2(unittest.TestCase):
avg_cost = layers.mean(cost)
opt = optimizer.SGD(learning_rate=0.001)
opt.minimize(avg_cost)
self.skip_set = set([cost.name, fc.name])
self.program = program
def test_inplace_ops(self):
print("before optimization")
print(str(self.program))
result_program = memory_optimize(self.program)
print("after optimization")
print(str(result_program))
result_program = self.program.clone()
memory_optimize(self.program)
old_vars = _get_vars(self.program)
new_vars = _get_vars(result_program)
self.assertTrue(old_vars != new_vars)
def test_skip_opt(self):
result_program = self.program.clone()
memory_optimize(self.program, skip_opt_set=self.skip_set)
old_vars = _get_vars(self.program)
new_vars = _get_vars(result_program)
self.assertTrue(old_vars != new_vars)
class TestMemoryTranspiler3(unittest.TestCase):
......
......@@ -14,6 +14,7 @@
from __future__ import print_function
import six
from collections import defaultdict, MutableSet
from .. import core
from ... import compat as cpt
......@@ -470,8 +471,21 @@ def memory_optimize(input_program,
Returns:
None
"""
def to_name_str(var):
if isinstance(var, Variable):
return var.desc.name()
elif isinstance(var, str):
return var
elif isinstance(var, six.string_types):
return str(var)
else:
raise TypeError(str(var) + " should be Variable or str")
if level != 0 and level != 1:
raise ValueError("only support opt_level 0 or 1.")
if skip_opt_set is not None and not isinstance(skip_opt_set, set):
raise ValueError("only support skip_opt_set as set.")
global PRINT_LOG
PRINT_LOG = print_log
if skip_grads:
......@@ -486,6 +500,8 @@ def memory_optimize(input_program,
skip_opt_set = grad_set
else:
skip_opt_set.update(grad_set)
if skip_opt_set is not None:
skip_opt_set = set(map(to_name_str, skip_opt_set))
cfgs = _get_cfgs(input_program)
for cfg in cfgs:
cfg.memory_optimize(skip_opt_set=skip_opt_set, level=level)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册