From 4fd83a8d9e4d6a70d5109b18497c665614be9ec8 Mon Sep 17 00:00:00 2001 From: guosheng Date: Tue, 24 Jul 2018 00:31:25 +0800 Subject: [PATCH] Fix the learning rate error caused by args.local in Transformer --- fluid/neural_machine_translation/transformer/train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fluid/neural_machine_translation/transformer/train.py b/fluid/neural_machine_translation/transformer/train.py index d2cd5a18..a9d543ef 100644 --- a/fluid/neural_machine_translation/transformer/train.py +++ b/fluid/neural_machine_translation/transformer/train.py @@ -377,6 +377,8 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler, for batch_id, data in enumerate(train_data()): feed_list = [] total_num_token = 0 + if args.local: + lr_rate = lr_scheduler.update_learning_rate() for place_id, data_buffer in enumerate( split_data( data, num_part=dev_count)): @@ -388,7 +390,6 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler, feed_kv_pairs = data_input_dict.items() + util_input_dict.items( ) if args.local: - lr_rate = lr_scheduler.update_learning_rate() feed_kv_pairs += { lr_scheduler.learning_rate.name: lr_rate }.items() -- GitLab