Unverified commit 7e6a2190, authored by zhouweiwei2014, committed by GitHub

fix UT test_lr_scheduler random fail (#39254)

Parent: 0e235e58
@@ -323,15 +323,15 @@ def step_lr(epoch_num, learning_rate, step_size, gamma=0.1, verbose=False):
 class TestLRScheduler(unittest.TestCase):
     def _test_static(self, python_func, paddle_api, kwarg, place):
+        scheduler = paddle_api(**kwarg)
+        adam = paddle.optimizer.Adam(learning_rate=scheduler)
         main_prog = paddle.static.Program()
         start_prog = paddle.static.Program()
         with paddle.static.program_guard(main_prog, start_prog):
             x = paddle.static.data(name='x', shape=[3, 4, 5])
-            y = paddle.static.data(name='y', shape=[3, 4, 5])
-            z = paddle.static.nn.fc(x, 100)
-            loss = paddle.mean(z)
-            scheduler = paddle_api(**kwarg)
-            adam = paddle.optimizer.Adam(learning_rate=scheduler)
+            loss = paddle.mean(x)
             adam.minimize(loss)
             lr_var = adam._global_learning_rate()
             test_prog = main_prog.clone()
@@ -339,14 +339,12 @@ class TestLRScheduler(unittest.TestCase):
         num = 0
         exe = paddle.static.Executor(place)
         exe.run(start_prog)
         for epoch in range(5):
             for batch_id in range(2):
                 out = exe.run(
                     main_prog,
-                    feed={
-                        'x': np.random.randn(3, 4, 5).astype('float32'),
-                        'y': np.random.randn(3, 4, 5).astype('float32')
-                    },
+                    feed={'x': np.random.randn(3, 4, 5).astype('float32')},
                     fetch_list=lr_var.name)
                 self.assertEqual(out, np.array(python_func(num, **kwarg)))
                 scheduler.step()
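Each fetched learning rate is checked against a pure-Python reference (python_func). The first hunk's header shows the signature of one such reference, step_lr; the body below is a sketch of what it plausibly computes, mirroring the semantics of paddle.optimizer.lr.StepDecay (only the signature comes from the diff; the body is an assumption):

def step_lr(epoch_num, learning_rate, step_size, gamma=0.1, verbose=False):
    # Decay the base rate by `gamma` once every `step_size` epochs
    # (assumed body, matching StepDecay semantics).
    return learning_rate * gamma**(epoch_num // step_size)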
@@ -356,10 +354,7 @@ class TestLRScheduler(unittest.TestCase):
             for batch_id in range(2):
                 out = exe.run(
                     test_prog,
-                    feed={
-                        'x': np.random.randn(3, 4, 5).astype('float32'),
-                        'y': np.random.randn(3, 4, 5).astype('float32')
-                    },
+                    feed={'x': np.random.randn(3, 4, 5).astype('float32')},
                     fetch_list=lr_var.name)
                 self.assertEqual(out, np.array(python_func(num, **kwarg)))
                 scheduler.step()
@@ -372,13 +367,12 @@ class TestLRScheduler(unittest.TestCase):
         for epoch in range(5):
             python_result = python_func(num, **kwarg)
             for batch_id in range(2):
-                _ = exe.run(
-                    compiled_train_prog,
-                    feed={
-                        'x': np.random.randn(12, 4, 5).astype('float32'),
-                        'y': np.random.randn(12, 4, 5).astype('float32')
-                    },
-                    fetch_list=lr_var.name)
+                _ = exe.run(compiled_train_prog,
+                            feed={
+                                'x':
+                                np.random.randn(12, 4, 5).astype('float32')
+                            },
+                            fetch_list=lr_var.name)
             scopes = compiled_train_prog._executor.local_scopes()
             out = np.array(scopes[0].var(lr_var.name).get_tensor())
             self.assertEqual(out, np.array(python_result))
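For the compiled (data-parallel) program, the test does not use the value returned by exe.run; it reads the learning rate straight out of the executor's per-device scopes. A commented sketch of that pattern, using the same names as the diff and assuming numpy is imported as np:

# One Scope per device; the LR variable lives in each local scope.
scopes = compiled_train_prog._executor.local_scopes()
lr_tensor = scopes[0].var(lr_var.name).get_tensor()  # underlying LoDTensor
out = np.array(lr_tensor)                            # copy into a numpy array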
@@ -399,13 +393,12 @@ class TestLRScheduler(unittest.TestCase):
         for epoch in range(5):
             python_result = python_func(num, **kwarg)
             for batch_id in range(2):
-                _ = exe.run(
-                    compiled_test_prog,
-                    feed={
-                        'x': np.random.randn(12, 4, 5).astype('float32'),
-                        'y': np.random.randn(12, 4, 5).astype('float32')
-                    },
-                    fetch_list=lr_var.name)
+                _ = exe.run(compiled_test_prog,
+                            feed={
+                                'x':
+                                np.random.randn(12, 4, 5).astype('float32')
+                            },
+                            fetch_list=lr_var.name)
             scopes = compiled_test_prog._executor.local_scopes()
             out = np.array(scopes[0].var(lr_var.name).get_tensor())
             self.assertEqual(out, np.array(python_result))
...
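Assembled from the hunks above, the simplified test reduces to the following deterministic pattern. This is a self-contained sketch, not the file verbatim: it assumes Paddle 2.x static mode and picks paddle.optimizer.lr.StepDecay as one concrete paddle_api, paired with the step_lr reference from earlier:

import numpy as np
import paddle

paddle.enable_static()

def step_lr(epoch_num, learning_rate, step_size, gamma=0.1):
    return learning_rate * gamma**(epoch_num // step_size)

# Build the scheduler and optimizer once, outside the program guard.
scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.5, step_size=3)
adam = paddle.optimizer.Adam(learning_rate=scheduler)

main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
    x = paddle.static.data(name='x', shape=[3, 4, 5])
    loss = paddle.mean(x)  # no fc layer: nothing randomly initialized
    adam.minimize(loss)
    lr_var = adam._global_learning_rate()

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(start_prog)
for epoch in range(5):
    out = exe.run(main_prog,
                  feed={'x': np.random.randn(3, 4, 5).astype('float32')},
                  fetch_list=[lr_var.name])
    assert np.allclose(out[0], step_lr(epoch, 0.5, 3))
    scheduler.step()  # advance the schedule once per epoch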