From 48f5612ba00077dfd009c85cdbed975c00228a1d Mon Sep 17 00:00:00 2001
From: gongweibao
Date: Mon, 16 Apr 2018 12:36:01 +0000
Subject: [PATCH] clearnup

---
Review notes (free text between the "---" marker and the diffstat is not part
of the commit message or of the diff):
- This copy of the patch had lost its line breaks and indentation; both were
  restored by hand. Hunk bodies match the line counts in their @@ headers,
  but the exact indentation of context lines is an assumption. Verify against
  the target files, or apply with `git apply --ignore-whitespace`.
- optim.py hunk: the header declares three pre-image lines but only two
  non-blank ones survived; the third is assumed to be a blank line and its
  position is a guess. TODO confirm.
- nmt_fluid.py disables the old LearningRateScheduler code by wrapping it in
  ''' string literals instead of deleting it.
- optim.py gains a bare layers.Print(self.learning_rate) call. TODO confirm
  that `layers` is imported in optim.py and that the print is meant to stay.
- TODO confirm that fluid.layers.learning_rate_scheduler provides
  nmt_nist_decay in the targeted Paddle build.

 .../transformer_nist_base/nmt_fluid.py | 22 ++++++++++++++++++-
 .../transformer_nist_base/optim.py     |  1 +
 2 files changed, 22 insertions(+), 1 deletion(-)

diff --git a/fluid/neural_machine_translation/transformer_nist_base/nmt_fluid.py b/fluid/neural_machine_translation/transformer_nist_base/nmt_fluid.py
index 12d9ba1e..3a4721c9 100644
--- a/fluid/neural_machine_translation/transformer_nist_base/nmt_fluid.py
+++ b/fluid/neural_machine_translation/transformer_nist_base/nmt_fluid.py
@@ -173,6 +173,14 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
     return input_dict
 
 
+def get_var(name,value):
+    return fluid.layers.create_global_var(
+        name=name,
+        shape=[1],
+        value=float(value),
+        dtype="float32",
+        persistable=True)
+
 def main():
     place = core.CPUPlace() if args.device == 'CPU' else core.CUDAPlace(
         args.device_id)
@@ -185,11 +193,21 @@ def main():
         ModelHyperParams.d_value, ModelHyperParams.d_model,
         ModelHyperParams.d_inner_hid, ModelHyperParams.dropout)
 
+    '''
     lr_scheduler = LearningRateScheduler(ModelHyperParams.d_model,
                                          TrainTaskConfig.warmup_steps, place,
                                          TrainTaskConfig.learning_rate)
+    '''
+
+    warmup_steps = get_var("warmup_steps", value=TrainTaskConfig.warmup_steps)
+    d_model = get_var("d_model", value=ModelHyperParams.d_model)
+
+    lr_decay = fluid.layers\
+            .learning_rate_scheduler\
+            .nmt_nist_decay(d_model, warmup_steps)
+
     optimizer = fluid.optimizer.Adam(
-        learning_rate=lr_scheduler.learning_rate,
+        learning_rate = lr_decay,
         beta1=TrainTaskConfig.beta1,
         beta2=TrainTaskConfig.beta2,
         epsilon=TrainTaskConfig.eps)
@@ -251,8 +269,10 @@ def main():
                     label_data_names, ModelHyperParams.eos_idx,
                     ModelHyperParams.eos_idx, ModelHyperParams.n_head,
                     ModelHyperParams.d_model)
+                '''
                 if args.local:
                     lr_scheduler.update_learning_rate(data_input)
+                '''
                 outs = exe.run(trainer_prog,
                                feed=data_input,
                                fetch_list=[sum_cost, avg_cost],
diff --git a/fluid/neural_machine_translation/transformer_nist_base/optim.py b/fluid/neural_machine_translation/transformer_nist_base/optim.py
index 2828b4bc..6a6bc129 100644
--- a/fluid/neural_machine_translation/transformer_nist_base/optim.py
+++ b/fluid/neural_machine_translation/transformer_nist_base/optim.py
@@ -39,3 +39,4 @@ class LearningRateScheduler(object):
 
         lr_tensor.set(np.array([lr_value], dtype="float32"), self.place)
         data_input[self.learning_rate.name] = lr_tensor
+        layers.Print(self.learning_rate)
-- 
GitLab