Commit af492312, authored by shibeiji

script update for bert

Parent ca6da675
@@ -117,8 +117,7 @@ def run_pretrain():
         decay_params = list(filter(cfg.Lamb.decay_filter, params))
         other_params = list(filter(lambda x: x not in decay_params, params))
         group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
-                        {'params': other_params},
-                        {'order_params': params}]
+                        {'params': other_params}]
         optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps)
     elif cfg.optimizer == 'Momentum':
         optimizer = Momentum(net_with_loss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
@@ -133,8 +132,7 @@ def run_pretrain():
         decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
         other_params = list(filter(lambda x: x not in decay_params, params))
         group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
-                        {'params': other_params, 'weight_decay': 0.0},
-                        {'order_params': params}]
+                        {'params': other_params, 'weight_decay': 0.0}]
         optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
     else:
...
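Both optimizer branches now build the same two-group structure and no longer append an {'order_params': params} entry. A minimal, framework-free sketch of that grouping logic follows; FakeParam and the name-based filter are illustrative stand-ins for MindSpore Parameters and the configured decay_filter, and the weight-decay values are placeholders.

# Illustrative sketch only: FakeParam and the name-based filter stand in for
# MindSpore Parameters and cfg.Lamb.decay_filter / cfg.AdamWeightDecay.decay_filter.
class FakeParam:
    def __init__(self, name):
        self.name = name

params = [FakeParam(n) for n in ('dense.weight', 'dense.bias', 'layernorm.gamma')]
decay_filter = lambda p: 'bias' not in p.name and 'layernorm' not in p.name
decay_params = list(filter(decay_filter, params))
other_params = list(filter(lambda x: x not in decay_params, params))

# After this commit the group list ends here; the former {'order_params': params}
# entry is no longer appended.
group_params = [{'params': decay_params, 'weight_decay': 0.01},   # placeholder decay value
                {'params': other_params, 'weight_decay': 0.0}]

print([p.name for p in group_params[0]['params']])   # ['dense.weight']
print([p.name for p in group_params[1]['params']])   # ['dense.bias', 'layernorm.gamma']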
@@ -22,7 +22,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.ops import composite as C
 from mindspore.common.tensor import Tensor
-from mindspore.common.parameter import Parameter, ParameterTuple
+from mindspore.common.parameter import Parameter
 from mindspore.common import dtype as mstype
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.train.parallel_utils import ParallelMode
@@ -55,7 +55,7 @@ class BertFinetuneCell(nn.Cell):
         super(BertFinetuneCell, self).__init__(auto_prefix=False)
         self.network = network
-        self.weights = ParameterTuple(network.trainable_params())
+        self.weights = optimizer.parameters
         self.optimizer = optimizer
         self.grad = C.GradOperation('grad',
                                     get_by_list=True,
@@ -158,7 +158,7 @@ class BertSquadCell(nn.Cell):
     def __init__(self, network, optimizer, scale_update_cell=None):
         super(BertSquadCell, self).__init__(auto_prefix=False)
         self.network = network
-        self.weights = ParameterTuple(network.trainable_params())
+        self.weights = optimizer.parameters
         self.optimizer = optimizer
         self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
         self.reducer_flag = False
...
@@ -21,7 +21,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.ops import composite as C
 from mindspore.common.tensor import Tensor
-from mindspore.common.parameter import Parameter, ParameterTuple
+from mindspore.common.parameter import Parameter
 from mindspore.common import dtype as mstype
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.train.parallel_utils import ParallelMode
@@ -270,7 +270,7 @@ class BertTrainOneStepCell(nn.Cell):
     def __init__(self, network, optimizer, sens=1.0):
         super(BertTrainOneStepCell, self).__init__(auto_prefix=False)
         self.network = network
-        self.weights = ParameterTuple(network.trainable_params())
+        self.weights = optimizer.parameters
         self.optimizer = optimizer
         self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
         self.sens = sens
@@ -349,7 +349,7 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
     def __init__(self, network, optimizer, scale_update_cell=None):
         super(BertTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
         self.network = network
-        self.weights = ParameterTuple(network.trainable_params())
+        self.weights = optimizer.parameters
         self.optimizer = optimizer
         self.grad = C.GradOperation('grad',
                                     get_by_list=True,
...
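The same one-line change recurs in BertFinetuneCell, BertSquadCell, BertTrainOneStepCell, and BertTrainOneStepWithLossScaleCell: instead of building a second ParameterTuple from network.trainable_params(), each wrapper cell reuses the parameter tuple the optimizer already holds, so gradients are computed for exactly the parameters the optimizer updates. A stripped-down sketch of the pattern, with toy placeholder classes rather than the real MindSpore cells:

# Toy placeholders; the real code subclasses mindspore.nn.Cell and uses a
# MindSpore optimizer, whose .parameters attribute is a ParameterTuple.
class ToyOptimizer:
    def __init__(self, parameters):
        self.parameters = tuple(parameters)

class ToyTrainOneStepCell:
    def __init__(self, network, optimizer):
        self.network = network
        # Reuse the optimizer's parameter tuple instead of rebuilding it
        # from network.trainable_params().
        self.weights = optimizer.parameters
        self.optimizer = optimizer

opt = ToyOptimizer(parameters=('w0', 'w1'))
cell = ToyTrainOneStepCell(network=object(), optimizer=opt)
print(cell.weights)   # ('w0', 'w1')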
@@ -133,7 +133,10 @@ class BertLearningRate(LearningRateSchedule):
     """
     def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
         super(BertLearningRate, self).__init__()
-        self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
+        self.warmup_flag = False
+        if warmup_steps > 0:
+            self.warmup_flag = True
+            self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
         self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
         self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
@@ -142,8 +145,11 @@ class BertLearningRate(LearningRateSchedule):
         self.cast = P.Cast()

     def construct(self, global_step):
-        is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
-        warmup_lr = self.warmup_lr(global_step)
         decay_lr = self.decay_lr(global_step)
-        lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
+        if self.warmup_flag:
+            is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
+            warmup_lr = self.warmup_lr(global_step)
+            lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
+        else:
+            lr = decay_lr
         return lr
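To make the effect of warmup_flag concrete, here is a framework-free sketch of the schedule after this change, using the usual linear warm-up and polynomial-decay formulas; the constants are illustrative and the function is not the MindSpore implementation itself. With warmup_steps == 0 the warm-up branch is skipped entirely and only the polynomial decay is evaluated.

# Framework-free sketch (illustrative constants, not the real MindSpore cells).
def bert_lr(global_step, lr=1e-4, end_lr=0.0, warmup_steps=0, decay_steps=1000, power=1.0):
    # PolynomialDecayLR-style decay term.
    decay_lr = (lr - end_lr) * (1 - min(global_step, decay_steps) / decay_steps) ** power + end_lr
    if warmup_steps > 0:                                  # mirrors self.warmup_flag
        warmup_lr = lr * global_step / warmup_steps       # WarmUpLR-style linear warm-up
        is_warmup = 1.0 if global_step < warmup_steps else 0.0
        return (1.0 - is_warmup) * decay_lr + is_warmup * warmup_lr
    return decay_lr

print(bert_lr(10, warmup_steps=0))     # decay-only path
print(bert_lr(10, warmup_steps=100))   # warm-up path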