Commit 4fd83a8d authored by guosheng

Fix the learning rate error caused by args.local in Transformer

Parent: 0b48d785
@@ -377,6 +377,8 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
         for batch_id, data in enumerate(train_data()):
             feed_list = []
             total_num_token = 0
+            if args.local:
+                lr_rate = lr_scheduler.update_learning_rate()
             for place_id, data_buffer in enumerate(
                     split_data(
                         data, num_part=dev_count)):
@@ -388,7 +390,6 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
                 feed_kv_pairs = data_input_dict.items() + util_input_dict.items(
                 )
                 if args.local:
-                    lr_rate = lr_scheduler.update_learning_rate()
                     feed_kv_pairs += {
                         lr_scheduler.learning_rate.name: lr_rate
                     }.items()
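
For context, a minimal, self-contained sketch (not the repository's code) of the pattern this commit applies: the learning-rate update is hoisted out of the per-device loop so the schedule advances once per mini-batch rather than once per device, and the same value is fed to every place. The NoamScheduler, split_data, and train_batch names below are simplified stand-ins for illustration, not PaddlePaddle APIs.

class NoamScheduler:
    """Toy warmup/decay scheduler standing in for the Transformer's lr_scheduler."""
    def __init__(self, d_model=512, warmup_steps=4000):
        self.d_model, self.warmup_steps, self.step = d_model, warmup_steps, 0

    def update_learning_rate(self):
        # Advances the schedule exactly once per call.
        self.step += 1
        return self.d_model ** -0.5 * min(self.step ** -0.5,
                                          self.step * self.warmup_steps ** -1.5)


def split_data(data, num_part):
    # Simplified round-robin split of one batch across num_part devices.
    return [data[i::num_part] for i in range(num_part)]


def train_batch(data, dev_count, lr_scheduler, local=True):
    feeds = []
    if local:
        # After the fix: update once per batch, before the per-device loop.
        # The old code called update_learning_rate() inside the loop, so the
        # schedule was stepped dev_count times for a single batch.
        lr_rate = lr_scheduler.update_learning_rate()
    for place_id, data_buffer in enumerate(split_data(data, num_part=dev_count)):
        feed = {"data": data_buffer}
        if local:
            feed["learning_rate"] = lr_rate  # same value for every device
        feeds.append(feed)
    return feeds


print(train_batch(list(range(8)), dev_count=2, lr_scheduler=NoamScheduler()))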