diff --git a/fluid/transformer/config.py b/fluid/transformer/config.py
index ab52bbd1104322371371b92f45363a9eb222cf9d..091ea175291c56d63e1d8b42a874516d9733f1cf 100644
--- a/fluid/transformer/config.py
+++ b/fluid/transformer/config.py
@@ -12,6 +12,9 @@ class TrainTaskConfig(object):
     beta2 = 0.98
     eps = 1e-9

+    # the params for learning rate scheduling
+    warmup_steps = 4000
+

 class ModelHyperParams(object):
     # Dictionary size for source and target language. This model directly uses
diff --git a/fluid/transformer/optim.py b/fluid/transformer/optim.py
new file mode 100644
index 0000000000000000000000000000000000000000..9905e6594a668b8e59fef1a4394714a6fcb8aeb6
--- /dev/null
+++ b/fluid/transformer/optim.py
@@ -0,0 +1,40 @@
+import numpy as np
+
+import paddle.fluid as fluid
+import paddle.fluid.layers as layers
+
+
+class LearningRateScheduler(object):
+    """
+    Wrapper for learning rate scheduling as described in the Transformer paper.
+    LearningRateScheduler adapts the learning rate externally and the adapted
+    learning rate will be fed into the main_program as input data.
+    """
+
+    def __init__(self,
+                 d_model,
+                 warmup_steps,
+                 place,
+                 learning_rate=0.001,
+                 current_steps=0,
+                 name="learning_rate"):
+        self.current_steps = current_steps
+        self.warmup_steps = warmup_steps
+        self.d_model = d_model
+        self.learning_rate = layers.create_global_var(
+            name=name,
+            shape=[1],
+            value=float(learning_rate),
+            dtype="float32",
+            persistable=True)
+        self.place = place
+
+    def update_learning_rate(self, data_input):
+        self.current_steps += 1
+        lr_value = np.power(self.d_model, -0.5) * np.min([
+            np.power(self.current_steps, -0.5),
+            np.power(self.warmup_steps, -1.5) * self.current_steps
+        ])
+        lr_tensor = fluid.LoDTensor()
+        lr_tensor.set(np.array([lr_value], dtype="float32"), self.place)
+        data_input[self.learning_rate.name] = lr_tensor
diff --git a/fluid/transformer/train.py b/fluid/transformer/train.py
index 1818beedadd9da5c1e835debed900a0bb68f14fb..b841ef4621d91e07f9d93e87a795c4605e7f30bc 100644
--- a/fluid/transformer/train.py
+++ b/fluid/transformer/train.py
@@ -4,6 +4,7 @@ import paddle.v2 as paddle
 import paddle.fluid as fluid

 from model import transformer, position_encoding_init
+from optim import LearningRateScheduler
 from config import TrainTaskConfig, ModelHyperParams, \
         pos_enc_param_names, input_data_names

@@ -88,6 +89,9 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,


 def main():
+    place = fluid.CUDAPlace(0) if TrainTaskConfig.use_gpu else fluid.CPUPlace()
+    exe = fluid.Executor(place)
+
     cost = transformer(
         ModelHyperParams.src_vocab_size + 1,
         ModelHyperParams.trg_vocab_size + 1, ModelHyperParams.max_length + 1,
@@ -97,8 +101,11 @@ def main():
         ModelHyperParams.dropout, ModelHyperParams.src_pad_idx,
         ModelHyperParams.trg_pad_idx, ModelHyperParams.pos_pad_idx)

+    lr_scheduler = LearningRateScheduler(ModelHyperParams.d_model,
+                                         TrainTaskConfig.warmup_steps, place,
+                                         TrainTaskConfig.learning_rate)
     optimizer = fluid.optimizer.Adam(
-        learning_rate=TrainTaskConfig.learning_rate,
+        learning_rate=lr_scheduler.learning_rate,
         beta1=TrainTaskConfig.beta1,
         beta2=TrainTaskConfig.beta2,
         epsilon=TrainTaskConfig.eps)
@@ -111,9 +118,6 @@ def main():
             buf_size=51200),
         batch_size=TrainTaskConfig.batch_size)

-    place = fluid.CUDAPlace(0) if TrainTaskConfig.use_gpu else fluid.CPUPlace()
-    exe = fluid.Executor(place)
-
     # Initialize the parameters.
     exe.run(fluid.framework.default_startup_program())
     for pos_enc_param_name in pos_enc_param_names:
@@ -125,10 +129,15 @@ def main():

     for pass_id in xrange(TrainTaskConfig.pass_num):
         for batch_id, data in enumerate(train_data()):
+            # The current program desc is coupled with batch_size, thus all
+            # mini-batches must have the same number of instances currently.
+            if len(data) != TrainTaskConfig.batch_size:
+                continue
             data_input = prepare_batch_input(
                 data, input_data_names, ModelHyperParams.src_pad_idx,
                 ModelHyperParams.trg_pad_idx, ModelHyperParams.max_length,
                 ModelHyperParams.n_head, place)
+            lr_scheduler.update_learning_rate(data_input)
             outs = exe.run(fluid.framework.default_main_program(),
                            feed=data_input,
                            fetch_list=[cost])
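
Note: update_learning_rate above computes the warmup schedule from "Attention Is
All You Need", lrate = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5),
on the host each step and feeds it in as input data. Below is a minimal
standalone sketch of that same formula, handy for sanity-checking or plotting
the schedule outside Fluid; the helper name noam_lr and the default d_model=512
(the base-model width from the paper) are illustrative assumptions, not taken
from this diff.

    import numpy as np

    def noam_lr(step, d_model=512, warmup_steps=4000):
        # Linear warmup for the first `warmup_steps` steps, then decay
        # proportional to step^-0.5. Assumes step >= 1 (the scheduler in
        # optim.py increments current_steps before computing, so step^-0.5
        # is never evaluated at 0).
        return np.power(d_model, -0.5) * min(
            np.power(step, -0.5), step * np.power(warmup_steps, -1.5))

    # The two branches of the min() meet at step == warmup_steps, where the
    # rate peaks (about 7.0e-4 for d_model=512, warmup_steps=4000).
    print(noam_lr(1), noam_lr(4000), noam_lr(100000))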