提交 a8612adb 编写于 作者: Y Yancey1989

fix lr scale test=develop

上级 86bb5838
......@@ -167,17 +167,13 @@ def cosine_decay(learning_rate, step_each_epoch, epochs=120):
return decayed_lr
def optimizer(learning_rate=0.01, lr_scale=1.0):
def _opt():
return fluid.optimizer.Momentum(
def optimizer(learning_rate=0.01):
optimizer = fluid.optimizer.Momentum(
learning_rate=cosine_decay(
learning_rate=learning_rate / lr_scale,
step_each_epoch=2,
epochs=1),
learning_rate=learning_rate, step_each_epoch=2, epochs=1),
momentum=0.9,
regularization=fluid.regularizer.L2Decay(1e-4))
return _opt
return optimizer
class TestResnet(TestParallelExecutorBase):
......@@ -220,7 +216,7 @@ class TestResnet(TestParallelExecutorBase):
batch_size=batch_size,
use_cuda=use_cuda,
use_reduce=False,
optimizer=optimizer())
optimizer=optimizer)
reduce_first_loss, reduce_last_loss = self.check_network_convergence(
model,
feed_dict={"image": img,
......@@ -229,7 +225,7 @@ class TestResnet(TestParallelExecutorBase):
batch_size=batch_size,
use_cuda=use_cuda,
use_reduce=True,
optimizer=optimizer())
optimizer=optimizer)
for loss in zip(all_reduce_first_loss, reduce_first_loss):
self.assertAlmostEquals(loss[0], loss[1], delta=1e-6)
......@@ -247,7 +243,7 @@ class TestResnet(TestParallelExecutorBase):
batch_size=batch_size,
use_cuda=use_cuda,
use_reduce=False,
optimizer=optimizer(),
optimizer=optimizer,
enable_sequential_execution=True)
reduce_first_loss_seq, reduce_last_loss_seq = self.check_network_convergence(
......@@ -258,7 +254,7 @@ class TestResnet(TestParallelExecutorBase):
batch_size=batch_size,
use_cuda=use_cuda,
use_reduce=True,
optimizer=optimizer(),
optimizer=optimizer,
enable_sequential_execution=True)
for loss in zip(all_reduce_first_loss, all_reduce_first_loss_seq):
......@@ -301,7 +297,7 @@ class TestResnet(TestParallelExecutorBase):
batch_size=batch_size,
use_cuda=use_cuda,
use_reduce=use_reduce,
optimizer=optimizer(),
optimizer=optimizer,
use_parallel_executor=False,
use_parallel_graph=use_parallel_graph)
parallel_first_loss, parallel_last_loss = self.check_network_convergence(
......@@ -312,7 +308,7 @@ class TestResnet(TestParallelExecutorBase):
batch_size=batch_size,
use_cuda=use_cuda,
use_reduce=use_reduce,
optimizer=optimizer(),
optimizer=optimizer,
use_parallel_graph=use_parallel_graph)
self.assertAlmostEquals(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册