未验证 提交 02b4e989 编写于 作者: L Leo Chen 提交者: GitHub

fix pruned_program_cache_key of Operator (#23594)

* fix init_gflags with 'python -c', test=develop

* fix pruned_program_cache_key of Operator, test=develop
上级 2c4b57e9
...@@ -354,7 +354,7 @@ def _to_name_str(var): ...@@ -354,7 +354,7 @@ def _to_name_str(var):
elif isinstance(var, six.string_types): elif isinstance(var, six.string_types):
return str(var) return str(var)
elif isinstance(var, Operator): elif isinstance(var, Operator):
return var.desc.type() return str(id(var))
else: else:
raise TypeError(str(var) + " should be Variable, Operator or str") raise TypeError(str(var) + " should be Variable, Operator or str")
......
...@@ -323,7 +323,7 @@ class TestExecutorRunAutoPrune(unittest.TestCase): ...@@ -323,7 +323,7 @@ class TestExecutorRunAutoPrune(unittest.TestCase):
def test_prune_with_cache_program(self): def test_prune_with_cache_program(self):
''' '''
When use_prune=True and use_program_cache=True, Executor should cache the pruned program. When use_prune=True, Executor should cache the pruned program.
If in next run, the program, feed, fetch are not changed, Executor use the cached pruned program, If in next run, the program, feed, fetch are not changed, Executor use the cached pruned program,
and needn't to call _prune_program() to prune the program. and needn't to call _prune_program() to prune the program.
In this test, we hack the Executor._prune_program with a mock function which do nothing but increase In this test, we hack the Executor._prune_program with a mock function which do nothing but increase
...@@ -350,16 +350,68 @@ class TestExecutorRunAutoPrune(unittest.TestCase): ...@@ -350,16 +350,68 @@ class TestExecutorRunAutoPrune(unittest.TestCase):
feed={'x': x_np, feed={'x': x_np,
'label': label_np}, 'label': label_np},
fetch_list=[loss1.name], fetch_list=[loss1.name],
use_prune=True, use_prune=True)
use_program_cache=True)
if i == 0: if i == 0:
self.assertEqual(exe.prune_called_times, 1) self.assertEqual(exe.prune_called_times, 1)
else: else:
self.assertEqual(exe.prune_called_times, 1) self.assertEqual(exe.prune_called_times, 1)
def test_prune_with_cache_program2(self):
'''
When use_prune=True, Executor should cache the pruned program.
If the only difference in fetch_list is optimize_ops during multiple runs,
the cache_keys should be different and get different pruned program.
'''
with _mock_guard(mock):
exe = fluid.Executor(fluid.CPUPlace())
exe.prune_called_times = 0
program = framework.Program()
startup_program = framework.Program()
scope = fluid.Scope()
with fluid.scope_guard(scope):
with fluid.program_guard(program, startup_program):
(x1, x2, y1, y2, label, loss1, loss2, w1_param_attrs,
w2_param_attrs) = self.net2()
adam_optimizer1 = fluid.optimizer.AdamOptimizer(
learning_rate=0.5)
train1 = adam_optimizer1.minimize(loss1)
adam_optimizer2 = fluid.optimizer.AdamOptimizer(
learning_rate=0.5)
train2 = adam_optimizer2.minimize(loss2)
exe.run(startup_program)
x_np = np.random.random(size=(10, 2)).astype('float32')
label_np = np.random.randint(
1, size=(10, 1)).astype('int64')
for i in range(10):
if i % 2:
res = exe.run(program,
feed={
'x1': x_np,
'x2': x_np,
'label': label_np
},
fetch_list=[loss1, loss2, train1],
use_prune=True)
else:
res = exe.run(program,
feed={
'x1': x_np,
'x2': x_np,
'label': label_np
},
fetch_list=[loss1, loss2, train2],
use_prune=True)
if i == 0:
self.assertEqual(exe.prune_called_times, 1)
elif i == 1:
self.assertEqual(exe.prune_called_times, 2)
else:
self.assertEqual(exe.prune_called_times, 2)
def test_prune_with_cache_compiled_program(self): def test_prune_with_cache_compiled_program(self):
''' '''
When use_prune=True and use_program_cache=True, Executor should cache the pruned program. When use_prune=True, Executor should cache the pruned program.
If in next run, the program, feed, fetch are not changed, Executor use the cached pruned program, If in next run, the program, feed, fetch are not changed, Executor use the cached pruned program,
and needn't to call _prune_program() to prune the program. and needn't to call _prune_program() to prune the program.
In this test, we hack the Executor._prune_program with a mock function which do nothing but increase In this test, we hack the Executor._prune_program with a mock function which do nothing but increase
...@@ -389,8 +441,7 @@ class TestExecutorRunAutoPrune(unittest.TestCase): ...@@ -389,8 +441,7 @@ class TestExecutorRunAutoPrune(unittest.TestCase):
feed={'x': x_np, feed={'x': x_np,
'label': label_np}, 'label': label_np},
fetch_list=[loss1.name], fetch_list=[loss1.name],
use_prune=True, use_prune=True)
use_program_cache=True)
if i == 0: if i == 0:
self.assertEqual(exe.prune_called_times, 1) self.assertEqual(exe.prune_called_times, 1)
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册