Unverified commit 7e6a2190, authored by zhouweiwei2014, committed via GitHub

fix UT test_lr_scheduler random fail (#39254)

Parent commit: 0e235e58
......@@ -323,15 +323,15 @@ def step_lr(epoch_num, learning_rate, step_size, gamma=0.1, verbose=False):
class TestLRScheduler(unittest.TestCase):
def _test_static(self, python_func, paddle_api, kwarg, place):
scheduler = paddle_api(**kwarg)
adam = paddle.optimizer.Adam(learning_rate=scheduler)
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
x = paddle.static.data(name='x', shape=[3, 4, 5])
y = paddle.static.data(name='y', shape=[3, 4, 5])
z = paddle.static.nn.fc(x, 100)
loss = paddle.mean(z)
scheduler = paddle_api(**kwarg)
adam = paddle.optimizer.Adam(learning_rate=scheduler)
loss = paddle.mean(x)
adam.minimize(loss)
lr_var = adam._global_learning_rate()
test_prog = main_prog.clone()
......@@ -339,14 +339,12 @@ class TestLRScheduler(unittest.TestCase):
num = 0
exe = paddle.static.Executor(place)
exe.run(start_prog)
for epoch in range(5):
for batch_id in range(2):
out = exe.run(
main_prog,
feed={
'x': np.random.randn(3, 4, 5).astype('float32'),
'y': np.random.randn(3, 4, 5).astype('float32')
},
feed={'x': np.random.randn(3, 4, 5).astype('float32')},
fetch_list=lr_var.name)
self.assertEqual(out, np.array(python_func(num, **kwarg)))
scheduler.step()
......@@ -356,10 +354,7 @@ class TestLRScheduler(unittest.TestCase):
for batch_id in range(2):
out = exe.run(
test_prog,
feed={
'x': np.random.randn(3, 4, 5).astype('float32'),
'y': np.random.randn(3, 4, 5).astype('float32')
},
feed={'x': np.random.randn(3, 4, 5).astype('float32')},
fetch_list=lr_var.name)
self.assertEqual(out, np.array(python_func(num, **kwarg)))
scheduler.step()
......@@ -372,13 +367,12 @@ class TestLRScheduler(unittest.TestCase):
for epoch in range(5):
python_result = python_func(num, **kwarg)
for batch_id in range(2):
_ = exe.run(
compiled_train_prog,
feed={
'x': np.random.randn(12, 4, 5).astype('float32'),
'y': np.random.randn(12, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
_ = exe.run(compiled_train_prog,
feed={
'x':
np.random.randn(12, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
scopes = compiled_train_prog._executor.local_scopes()
out = np.array(scopes[0].var(lr_var.name).get_tensor())
self.assertEqual(out, np.array(python_result))
......@@ -399,13 +393,12 @@ class TestLRScheduler(unittest.TestCase):
for epoch in range(5):
python_result = python_func(num, **kwarg)
for batch_id in range(2):
_ = exe.run(
compiled_test_prog,
feed={
'x': np.random.randn(12, 4, 5).astype('float32'),
'y': np.random.randn(12, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
_ = exe.run(compiled_test_prog,
feed={
'x':
np.random.randn(12, 4, 5).astype('float32')
},
fetch_list=lr_var.name)
scopes = compiled_test_prog._executor.local_scopes()
out = np.array(scopes[0].var(lr_var.name).get_tensor())
self.assertEqual(out, np.array(python_result))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.