diff --git a/fluid/neural_machine_translation/transformer/train.py b/fluid/neural_machine_translation/transformer/train.py index d2cd5a185b2b4e2a35b5a485cd2be8b6e0f488de..a9d543ef21a6bb119381e311d03176b824a4357d 100644 --- a/fluid/neural_machine_translation/transformer/train.py +++ b/fluid/neural_machine_translation/transformer/train.py @@ -377,6 +377,8 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler, for batch_id, data in enumerate(train_data()): feed_list = [] total_num_token = 0 + if args.local: + lr_rate = lr_scheduler.update_learning_rate() for place_id, data_buffer in enumerate( split_data( data, num_part=dev_count)): @@ -388,7 +390,6 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler, feed_kv_pairs = data_input_dict.items() + util_input_dict.items( ) if args.local: - lr_rate = lr_scheduler.update_learning_rate() feed_kv_pairs += { lr_scheduler.learning_rate.name: lr_rate }.items()