From eb3f70a0c7043e951e9a943e713cb71857aef526 Mon Sep 17 00:00:00 2001
From: yoonlee666
Date: Wed, 6 May 2020 17:34:27 +0800
Subject: [PATCH] add warmup_steps in AdamWeightDecayDynamicLR optimizer

---
 mindspore/nn/optim/adam.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/mindspore/nn/optim/adam.py b/mindspore/nn/optim/adam.py
index 1a386556d..6142b126c 100755
--- a/mindspore/nn/optim/adam.py
+++ b/mindspore/nn/optim/adam.py
@@ -327,12 +327,17 @@ class AdamWeightDecayDynamicLR(Optimizer):
                  beta2=0.999,
                  eps=1e-6,
                  weight_decay=0.0,
-                 decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
+                 decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name,
+                 warmup_steps=0):
         super(AdamWeightDecayDynamicLR, self).__init__(learning_rate, params)
         _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
         _check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, power, self.cls_name)
         # turn them to scalar when me support scalar/tensor mix operations
         self.global_step = Parameter(initializer(0, [1]), name="global_step")
+        self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
+        self.warmup_flag = False
+        if warmup_steps > 0:
+            self.warmup_flag = True
         self.decay_steps = Tensor(np.array([decay_steps]).astype(np.float32))
         self.end_learning_rate = Tensor(np.array([end_learning_rate]).astype(np.float32))
         self.diff_learning_rate = Tensor(np.array([learning_rate - end_learning_rate]).astype(np.float32))
@@ -348,12 +353,20 @@ class AdamWeightDecayDynamicLR(Optimizer):
         self.hyper_map = C.HyperMap()
         self.min = P.Minimum()
         self.pow = P.Pow()
+        self.greater = P.Greater()
         self.one = Tensor(np.array([1.0]).astype(np.float32))
+        self.cast = P.Cast()
+        self.start_learning_rate = Tensor(np.array([learning_rate]).astype(np.float32))

     def construct(self, gradients):
         step = self.min(self.global_step, self.decay_steps)
         p = step / self.decay_steps
         lr = self.diff_learning_rate * self.pow(self.one - p, self.power) + self.end_learning_rate
+        if self.warmup_flag:
+            warmup_percent = self.global_step / self.warmup_steps
+            warmup_lr = self.start_learning_rate * warmup_percent
+            is_warmup = self.cast(self.greater(self.warmup_steps, self.global_step), mstype.float32)
+            lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr
         updated_velocity = self.hyper_map(F.partial(adam_opt, self.beta1, self.beta2, self.eps, lr,
                                                     self.weight_decay_tensor),
                                           self.params, self.moments1, self.moments2, gradients, self.decay_flag)
-- 
GitLab
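
Note on the schedule this patch implements: the existing construct() computes a polynomial decay from learning_rate to end_learning_rate over decay_steps; the added lines blend in a linear warmup that scales the base learning rate by global_step / warmup_steps while global_step < warmup_steps. Below is a minimal sketch of that arithmetic in plain Python (not MindSpore graph ops), included only to illustrate the math; the function name warmup_polynomial_lr is illustrative and not part of the patch.

# Sketch of the learning-rate schedule added by this patch, assuming the same
# parameter meanings as AdamWeightDecayDynamicLR. Plain Python stands in for
# P.Minimum, P.Pow, P.Greater and P.Cast.
def warmup_polynomial_lr(global_step, learning_rate, end_learning_rate,
                         decay_steps, power, warmup_steps=0):
    # Polynomial-decay branch, matching the existing construct() logic:
    # lr = (lr0 - lr_end) * (1 - min(step, decay_steps) / decay_steps) ** power + lr_end
    step = min(global_step, decay_steps)
    p = step / decay_steps
    lr = (learning_rate - end_learning_rate) * (1.0 - p) ** power + end_learning_rate

    # Warmup branch, matching the added lines: while warmup_steps > global_step,
    # replace lr with learning_rate scaled linearly by global_step / warmup_steps.
    if warmup_steps > 0:
        warmup_percent = global_step / warmup_steps
        warmup_lr = learning_rate * warmup_percent
        is_warmup = 1.0 if warmup_steps > global_step else 0.0
        lr = (1.0 - is_warmup) * lr + is_warmup * warmup_lr
    return lr

# Example (hypothetical values): with learning_rate=1e-3 and warmup_steps=100,
# step 50 yields 5e-4 (still warming up); from step 100 onward the value follows
# the polynomial decay toward end_learning_rate, exactly as before the patch.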