Commit f9243e6a, authored by songyouwei, committed by hong

Optimizer init with parameters (#4137)

test=develop
Parent 57ce8e2f
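
This commit adapts the example models to the dygraph optimizer API that binds parameters at construction time: each optimizer now receives a parameter_list when it is created, and minimize() is then called with the loss only, instead of taking a parameter_list argument on every call. A minimal before/after sketch of the pattern (not code from this commit; it assumes a Paddle 1.7-era fluid.dygraph API where Linear and the parameter_list constructor argument are available):

    import numpy as np
    import paddle.fluid as fluid

    with fluid.dygraph.guard():
        model = fluid.dygraph.Linear(4, 1)   # small stand-in network
        x = fluid.dygraph.to_variable(np.random.rand(8, 4).astype('float32'))
        loss = fluid.layers.reduce_mean(model(x))

        # Old style: parameters handed to minimize() on every call
        #   opt = fluid.optimizer.AdamOptimizer(learning_rate=1e-3)
        #   opt.minimize(loss, parameter_list=model.parameters())

        # New style: parameters bound once, when the optimizer is built
        opt = fluid.optimizer.AdamOptimizer(
            learning_rate=1e-3, parameter_list=model.parameters())
        loss.backward()
        opt.minimize(loss)
        model.clear_gradients()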
@@ -47,7 +47,7 @@ lambda_identity = 0.5
 step_per_epoch = 2974
-def optimizer_setting():
+def optimizer_setting(parameters):
     lr = 0.0002
     optimizer = fluid.optimizer.Adam(
         learning_rate=fluid.layers.piecewise_decay(
@@ -56,6 +56,7 @@ def optimizer_setting():
                 140 * step_per_epoch, 160 * step_per_epoch, 180 * step_per_epoch
             ],
             values=[lr, lr * 0.8, lr * 0.6, lr * 0.4, lr * 0.2, lr * 0.1]),
+        parameter_list=parameters,
         beta1=0.5)
     return optimizer
@@ -88,9 +89,14 @@ def train(args):
         losses = [[], []]
         t_time = 0
-        optimizer1 = optimizer_setting()
-        optimizer2 = optimizer_setting()
-        optimizer3 = optimizer_setting()
+        vars_G = cycle_gan.build_generator_resnet_9blocks_a.parameters() + cycle_gan.build_generator_resnet_9blocks_b.parameters()
+        vars_da = cycle_gan.build_gen_discriminator_a.parameters()
+        vars_db = cycle_gan.build_gen_discriminator_b.parameters()
+        optimizer1 = optimizer_setting(vars_G)
+        optimizer2 = optimizer_setting(vars_da)
+        optimizer3 = optimizer_setting(vars_db)
         for epoch in range(args.epoch):
             batch_id = 0
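With the parameters bound up front, each optimizer owns exactly one sub-network: optimizer1 the two generators, optimizer2 and optimizer3 one discriminator each. The minimize() calls later in the training loop therefore only update that subset, which is why the per-call parameter_list arguments are dropped in the hunks below.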
@@ -114,9 +120,8 @@ def train(args):
                 g_loss_out = g_loss.numpy()
                 g_loss.backward()
-                vars_G = cycle_gan.build_generator_resnet_9blocks_a.parameters() + cycle_gan.build_generator_resnet_9blocks_b.parameters()
-                optimizer1.minimize(g_loss, parameter_list=vars_G)
+                optimizer1.minimize(g_loss)
                 cycle_gan.clear_gradients()
                 fake_pool_B = B_pool.pool_image(fake_B).numpy()
@@ -137,8 +142,7 @@ def train(args):
                 d_loss_A = fluid.layers.reduce_mean(d_loss_A)
                 d_loss_A.backward()
-                vars_da = cycle_gan.build_gen_discriminator_a.parameters()
-                optimizer2.minimize(d_loss_A, parameter_list=vars_da)
+                optimizer2.minimize(d_loss_A)
                 cycle_gan.clear_gradients()
                 # optimize the d_B network
@@ -150,8 +154,7 @@ def train(args):
                 d_loss_B = fluid.layers.reduce_mean(d_loss_B)
                 d_loss_B.backward()
-                vars_db = cycle_gan.build_gen_discriminator_b.parameters()
-                optimizer3.minimize(d_loss_B, parameter_list=vars_db)
+                optimizer3.minimize(d_loss_B)
                 cycle_gan.clear_gradients()
......
@@ -187,7 +187,7 @@ def train_mnist(args):
         if args.use_data_parallel:
             strategy = fluid.dygraph.parallel.prepare_context()
         mnist = MNIST()
-        adam = AdamOptimizer(learning_rate=0.001)
+        adam = AdamOptimizer(learning_rate=0.001, parameter_list=mnist.parameters())
         if args.use_data_parallel:
             mnist = fluid.dygraph.parallel.DataParallel(mnist, strategy)
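Note that mnist.parameters() is read before the layer is (optionally) wrapped in DataParallel, so the optimizer is bound to the underlying layer's parameters in both the single-card and the data-parallel path.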
......
@@ -68,7 +68,7 @@ with fluid.dygraph.guard():
     policy = Policy()
     eps = np.finfo(np.float32).eps.item()
-    optimizer = fluid.optimizer.AdamOptimizer(learning_rate=3e-2)
+    optimizer = fluid.optimizer.AdamOptimizer(learning_rate=3e-2, parameter_list=policy.parameters())
     def get_mean_and_std(values=[]):
         n = 0.
......
@@ -67,7 +67,7 @@ with fluid.dygraph.guard():
     policy = Policy()
     eps = np.finfo(np.float32).eps.item()
-    optimizer = fluid.optimizer.AdamOptimizer(learning_rate=1e-2)
+    optimizer = fluid.optimizer.AdamOptimizer(learning_rate=1e-2, parameter_list=policy.parameters())
     def get_mean_and_std(values=[]):
         n = 0.
......
@@ -68,7 +68,7 @@ with fluid.dygraph.guard():
     policy = Policy()
     eps = np.finfo(np.float32).eps.item()
-    optimizer = fluid.optimizer.AdamOptimizer(learning_rate=3e-2)
+    optimizer = fluid.optimizer.AdamOptimizer(learning_rate=3e-2, parameter_list=policy.parameters())
     def get_mean_and_std(values=[]):
         n = 0.
......
@@ -67,7 +67,7 @@ with fluid.dygraph.guard():
     policy = Policy()
     eps = np.finfo(np.float32).eps.item()
-    optimizer = fluid.optimizer.AdamOptimizer(learning_rate=1e-2)
+    optimizer = fluid.optimizer.AdamOptimizer(learning_rate=1e-2, parameter_list=policy.parameters())
     def get_mean_and_std(values=[]):
         n = 0.
......