提交 48f5612b 编写于 作者: G gongweibao

cleanup

上级 6fc711cd
...@@ -173,6 +173,14 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx, ...@@ -173,6 +173,14 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
return input_dict return input_dict
def get_var(name, value):
    """Create a persistable, single-element float32 global variable.

    Args:
        name: Name passed through to the created variable.
        value: Initial value; it is converted with ``float()`` before use.

    Returns:
        The result of ``fluid.layers.create_global_var`` for these settings.
    """
    # Resolve the factory first so lookup happens before the value conversion.
    make_global_var = fluid.layers.create_global_var
    initial_value = float(value)
    return make_global_var(
        name=name,
        shape=[1],
        value=initial_value,
        dtype="float32",
        persistable=True,
    )
def main(): def main():
place = core.CPUPlace() if args.device == 'CPU' else core.CUDAPlace( place = core.CPUPlace() if args.device == 'CPU' else core.CUDAPlace(
args.device_id) args.device_id)
...@@ -185,11 +193,21 @@ def main(): ...@@ -185,11 +193,21 @@ def main():
ModelHyperParams.d_value, ModelHyperParams.d_model, ModelHyperParams.d_value, ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid, ModelHyperParams.dropout) ModelHyperParams.d_inner_hid, ModelHyperParams.dropout)
'''
lr_scheduler = LearningRateScheduler(ModelHyperParams.d_model, lr_scheduler = LearningRateScheduler(ModelHyperParams.d_model,
TrainTaskConfig.warmup_steps, place, TrainTaskConfig.warmup_steps, place,
TrainTaskConfig.learning_rate) TrainTaskConfig.learning_rate)
'''
warmup_steps = get_var("warmup_steps", value=TrainTaskConfig.warmup_steps)
d_model = get_var("d_model", value=ModelHyperParams.d_model)
lr_decay = fluid.layers\
.learning_rate_scheduler\
.nmt_nist_decay(d_model, warmup_steps)
optimizer = fluid.optimizer.Adam( optimizer = fluid.optimizer.Adam(
learning_rate=lr_scheduler.learning_rate, learning_rate = lr_decay,
beta1=TrainTaskConfig.beta1, beta1=TrainTaskConfig.beta1,
beta2=TrainTaskConfig.beta2, beta2=TrainTaskConfig.beta2,
epsilon=TrainTaskConfig.eps) epsilon=TrainTaskConfig.eps)
...@@ -251,8 +269,10 @@ def main(): ...@@ -251,8 +269,10 @@ def main():
label_data_names, ModelHyperParams.eos_idx, label_data_names, ModelHyperParams.eos_idx,
ModelHyperParams.eos_idx, ModelHyperParams.n_head, ModelHyperParams.eos_idx, ModelHyperParams.n_head,
ModelHyperParams.d_model) ModelHyperParams.d_model)
'''
if args.local: if args.local:
lr_scheduler.update_learning_rate(data_input) lr_scheduler.update_learning_rate(data_input)
'''
outs = exe.run(trainer_prog, outs = exe.run(trainer_prog,
feed=data_input, feed=data_input,
fetch_list=[sum_cost, avg_cost], fetch_list=[sum_cost, avg_cost],
......
...@@ -39,3 +39,4 @@ class LearningRateScheduler(object): ...@@ -39,3 +39,4 @@ class LearningRateScheduler(object):
lr_tensor.set(np.array([lr_value], dtype="float32"), self.place) lr_tensor.set(np.array([lr_value], dtype="float32"), self.place)
data_input[self.learning_rate.name] = lr_tensor data_input[self.learning_rate.name] = lr_tensor
layers.Print(self.learning_rate)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册