提交 00776b16 编写于 作者: D dzhwinter 提交者: chengduo

fix memory opt skip set by name (#14774)

* random failed. rerun ci. test=develop

* windows failed. rerun ci. test=develop
上级 c4c5f0b8
...@@ -22,6 +22,15 @@ from paddle.fluid.framework import Program, program_guard ...@@ -22,6 +22,15 @@ from paddle.fluid.framework import Program, program_guard
from paddle.fluid.transpiler import memory_optimize from paddle.fluid.transpiler import memory_optimize
def _get_vars(prog):
assert (isinstance(prog, Program))
all_vars = set()
for op in prog.global_block().ops:
all_vars.update(op.input_arg_names)
all_vars.update(op.output_arg_names)
return all_vars
class TestControlFlowGraph(unittest.TestCase): class TestControlFlowGraph(unittest.TestCase):
def setUp(self): def setUp(self):
program = Program() program = Program()
...@@ -37,11 +46,11 @@ class TestControlFlowGraph(unittest.TestCase): ...@@ -37,11 +46,11 @@ class TestControlFlowGraph(unittest.TestCase):
self.program = program self.program = program
def test_control_flow_graph(self): def test_control_flow_graph(self):
print("before optimization") result_program = self.program.clone()
print(str(self.program)) memory_optimize(self.program)
result_program = memory_optimize(self.program) old_vars = _get_vars(self.program)
print("after optimization") new_vars = _get_vars(result_program)
print(str(result_program)) self.assertTrue(old_vars != new_vars)
class TestMemoryTranspiler2(unittest.TestCase): class TestMemoryTranspiler2(unittest.TestCase):
...@@ -58,14 +67,22 @@ class TestMemoryTranspiler2(unittest.TestCase): ...@@ -58,14 +67,22 @@ class TestMemoryTranspiler2(unittest.TestCase):
avg_cost = layers.mean(cost) avg_cost = layers.mean(cost)
opt = optimizer.SGD(learning_rate=0.001) opt = optimizer.SGD(learning_rate=0.001)
opt.minimize(avg_cost) opt.minimize(avg_cost)
self.skip_set = set([cost.name, fc.name])
self.program = program self.program = program
def test_inplace_ops(self): def test_inplace_ops(self):
print("before optimization") result_program = self.program.clone()
print(str(self.program)) memory_optimize(self.program)
result_program = memory_optimize(self.program) old_vars = _get_vars(self.program)
print("after optimization") new_vars = _get_vars(result_program)
print(str(result_program)) self.assertTrue(old_vars != new_vars)
def test_skip_opt(self):
result_program = self.program.clone()
memory_optimize(self.program, skip_opt_set=self.skip_set)
old_vars = _get_vars(self.program)
new_vars = _get_vars(result_program)
self.assertTrue(old_vars != new_vars)
class TestMemoryTranspiler3(unittest.TestCase): class TestMemoryTranspiler3(unittest.TestCase):
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
from __future__ import print_function from __future__ import print_function
import six
from collections import defaultdict, MutableSet from collections import defaultdict, MutableSet
from .. import core from .. import core
from ... import compat as cpt from ... import compat as cpt
...@@ -470,8 +471,21 @@ def memory_optimize(input_program, ...@@ -470,8 +471,21 @@ def memory_optimize(input_program,
Returns: Returns:
None None
""" """
def to_name_str(var):
if isinstance(var, Variable):
return var.desc.name()
elif isinstance(var, str):
return var
elif isinstance(var, six.string_types):
return str(var)
else:
raise TypeError(str(var) + " should be Variable or str")
if level != 0 and level != 1: if level != 0 and level != 1:
raise ValueError("only support opt_level 0 or 1.") raise ValueError("only support opt_level 0 or 1.")
if skip_opt_set is not None and not isinstance(skip_opt_set, set):
raise ValueError("only support skip_opt_set as set.")
global PRINT_LOG global PRINT_LOG
PRINT_LOG = print_log PRINT_LOG = print_log
if skip_grads: if skip_grads:
...@@ -486,6 +500,8 @@ def memory_optimize(input_program, ...@@ -486,6 +500,8 @@ def memory_optimize(input_program,
skip_opt_set = grad_set skip_opt_set = grad_set
else: else:
skip_opt_set.update(grad_set) skip_opt_set.update(grad_set)
if skip_opt_set is not None:
skip_opt_set = set(map(to_name_str, skip_opt_set))
cfgs = _get_cfgs(input_program) cfgs = _get_cfgs(input_program)
for cfg in cfgs: for cfg in cfgs:
cfg.memory_optimize(skip_opt_set=skip_opt_set, level=level) cfg.memory_optimize(skip_opt_set=skip_opt_set, level=level)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册