Commit 4b8ba0ef authored by guosheng

Add learning rate scheduling in Transformer

Parent 2c00d0f9
...
@@ -12,6 +12,9 @@ class TrainTaskConfig(object):
     beta2 = 0.98
     eps = 1e-9
+    # the params for learning rate scheduling
+    warmup_steps = 4000


 class ModelHyperParams(object):
     # Dictionary size for source and target language. This model directly uses
...
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers


class LearningRateScheduler(object):
    """
    Wrapper for learning rate scheduling as described in the Transformer paper.
    LearningRateScheduler adapts the learning rate externally, and the adapted
    learning rate will be fed into the main_program as input data.
    """

    def __init__(self,
                 d_model,
                 warmup_steps,
                 place,
                 learning_rate=0.001,
                 current_steps=0,
                 name="learning_rate"):
        self.current_steps = current_steps
        self.warmup_steps = warmup_steps
        self.d_model = d_model
        self.learning_rate = layers.create_global_var(
            name=name,
            shape=[1],
            value=float(learning_rate),
            dtype="float32",
            persistable=True)
        self.place = place

    def update_learning_rate(self, data_input):
        self.current_steps += 1
        lr_value = np.power(self.d_model, -0.5) * np.min([
            np.power(self.current_steps, -0.5),
            np.power(self.warmup_steps, -1.5) * self.current_steps
        ])
        lr_tensor = fluid.LoDTensor()
        lr_tensor.set(np.array([lr_value], dtype="float32"), self.place)
        data_input[self.learning_rate.name] = lr_tensor
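For reference, update_learning_rate implements the warmup schedule from the Transformer paper: lr = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5), i.e. the rate grows linearly for the first warmup_steps training steps and then decays with the inverse square root of the step number. A minimal standalone sketch of the same rule (plain NumPy; d_model=512 is assumed here only for illustration, warmup_steps=4000 matches the new config default):

```python
import numpy as np

def transformer_lr(step, d_model=512, warmup_steps=4000):
    # Same rule as LearningRateScheduler.update_learning_rate:
    # linear warmup up to warmup_steps, inverse-square-root decay afterwards.
    return np.power(d_model, -0.5) * min(
        np.power(step, -0.5), step * np.power(warmup_steps, -1.5))

# The rate peaks exactly at step == warmup_steps and decays from there.
for step in (1, 1000, 4000, 10000, 100000):
    print(step, transformer_lr(step))
```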
...
@@ -4,6 +4,7 @@ import paddle.v2 as paddle
 import paddle.fluid as fluid
 from model import transformer, position_encoding_init
+from optim import LearningRateScheduler
 from config import TrainTaskConfig, ModelHyperParams, \
     pos_enc_param_names, input_data_names
...
@@ -88,6 +89,9 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,

 def main():
+    place = fluid.CUDAPlace(0) if TrainTaskConfig.use_gpu else fluid.CPUPlace()
+    exe = fluid.Executor(place)
+
     cost = transformer(
         ModelHyperParams.src_vocab_size + 1,
         ModelHyperParams.trg_vocab_size + 1, ModelHyperParams.max_length + 1,
...
@@ -97,8 +101,11 @@ def main():
         ModelHyperParams.dropout, ModelHyperParams.src_pad_idx,
         ModelHyperParams.trg_pad_idx, ModelHyperParams.pos_pad_idx)

+    lr_scheduler = LearningRateScheduler(ModelHyperParams.d_model,
+                                         TrainTaskConfig.warmup_steps, place,
+                                         TrainTaskConfig.learning_rate)
     optimizer = fluid.optimizer.Adam(
-        learning_rate=TrainTaskConfig.learning_rate,
+        learning_rate=lr_scheduler.learning_rate,
         beta1=TrainTaskConfig.beta1,
         beta2=TrainTaskConfig.beta2,
         epsilon=TrainTaskConfig.eps)
...
@@ -111,9 +118,6 @@ def main():
             buf_size=51200),
         batch_size=TrainTaskConfig.batch_size)

-    place = fluid.CUDAPlace(0) if TrainTaskConfig.use_gpu else fluid.CPUPlace()
-    exe = fluid.Executor(place)
-
     # Initialize the parameters.
     exe.run(fluid.framework.default_startup_program())
     for pos_enc_param_name in pos_enc_param_names:
...
@@ -125,10 +129,15 @@ def main():
     for pass_id in xrange(TrainTaskConfig.pass_num):
         for batch_id, data in enumerate(train_data()):
+            # The current program desc is coupled with batch_size, thus all
+            # mini-batches must have the same number of instances currently.
+            if len(data) != TrainTaskConfig.batch_size:
+                continue
             data_input = prepare_batch_input(
                 data, input_data_names, ModelHyperParams.src_pad_idx,
                 ModelHyperParams.trg_pad_idx, ModelHyperParams.max_length,
                 ModelHyperParams.n_head, place)
+            lr_scheduler.update_learning_rate(data_input)
             outs = exe.run(fluid.framework.default_main_program(),
                            feed=data_input,
                            fetch_list=[cost])
...
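The mechanism being relied on here: the Adam optimizer reads its learning rate from the persistable global variable created by LearningRateScheduler, and feeding a tensor under that variable's name overwrites its value before each iteration runs. Below is a rough, untested sketch of that feed-override pattern in isolation, assuming the paddle.fluid API of this commit's era; the toy program y = x * lr and every name in it are hypothetical, not part of this change.

```python
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers

place = fluid.CPUPlace()
# A persistable global variable, like LearningRateScheduler.learning_rate.
lr = layers.create_global_var(
    name="learning_rate", shape=[1], value=0.001,
    dtype="float32", persistable=True)
x = layers.data(name="x", shape=[1], dtype="float32")
y = layers.elementwise_mul(x, lr)  # stands in for the optimizer's use of lr

exe = fluid.Executor(place)
exe.run(fluid.framework.default_startup_program())

for step in range(1, 4):
    # Compute the new value on the host and wrap it in a LoDTensor,
    # exactly as update_learning_rate does, then feed it by variable name.
    lr_tensor = fluid.LoDTensor()
    lr_tensor.set(np.array([0.001 * step], dtype="float32"), place)
    x_tensor = fluid.LoDTensor()
    x_tensor.set(np.ones([1, 1], dtype="float32"), place)
    outs = exe.run(fluid.framework.default_main_program(),
                   feed={"x": x_tensor, "learning_rate": lr_tensor},
                   fetch_list=[y])
    print(step, np.array(outs[0]))  # y reflects the freshly fed rate
```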