Commit ac958362 authored by chenhaozhe

add global norm in bert

Parent: 7371cedd
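The change wires a new `enable_global_norm` switch from the pretraining config into both loss-scale training cells; when the switch is set, gradients are clipped by the new `ClipByGlobalNorm` cell added to `utils.py` instead of the existing per-tensor `clip_grad` call. A minimal sketch of flipping the switch, assuming the config object is the `cfg` edict that `run_pretrain.py` already uses (the import path below is an assumption; file names are elided in this diff view):

# Sketch only: 'enable_global_norm' is the key this commit adds to cfg;
# the module path is assumed, not shown in the diff.
from src.config import cfg

cfg.enable_global_norm = True  # route gradients through ClipByGlobalNorm instead of clip_grad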
@@ -178,12 +178,14 @@ def run_pretrain():
         if args_opt.accumulation_steps <= 1:
             net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
-                                                               scale_update_cell=update_cell)
+                                                               scale_update_cell=update_cell,
+                                                               enable_global_norm=cfg.enable_global_norm)
         else:
             accumulation_steps = args_opt.accumulation_steps
             net_with_grads = BertTrainAccumulateStepsWithLossScaleCell(net_with_loss, optimizer=optimizer,
-                                                                       scale_update_cell=update_cell,
-                                                                       accumulation_steps=accumulation_steps)
+                                                                       scale_update_cell=update_cell,
+                                                                       accumulation_steps=accumulation_steps,
+                                                                       enable_global_norm=cfg.enable_global_norm)
     else:
         net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)
...
@@ -29,6 +29,7 @@ from mindspore.communication.management import get_group_size
 from mindspore import context
 from mindspore.ops import _selected_ops
 from .bert_model import BertModel
+from .utils import ClipByGlobalNorm

 GRADIENT_CLIP_TYPE = 1
 GRADIENT_CLIP_VALUE = 1.0
@@ -348,11 +349,12 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
         optimizer (Optimizer): Optimizer for updating the weights.
         scale_update_cell (Cell): Cell to do the loss scale. Default: None.
     """
-    def __init__(self, network, optimizer, scale_update_cell=None):
+    def __init__(self, network, optimizer, scale_update_cell=None, enable_global_norm=False):
         super(BertTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
         self.network = network
         self.weights = optimizer.parameters
         self.optimizer = optimizer
+        self.enable_global_norm = enable_global_norm
         self.grad = C.GradOperation(get_by_list=True,
                                     sens_param=True)
         self.reducer_flag = False
@@ -419,6 +421,9 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
         # apply grad reducer on grads
         grads = self.grad_reducer(grads)
         grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
-        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
+        if self.enable_global_norm:
+            grads = ClipByGlobalNorm()(grads)
+        else:
+            grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
         self.get_status(init)
         flag_sum = self.reduce_sum(init, (0,))
@@ -474,12 +479,13 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
         accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
                                   batch_size * accumulation_steps. Default: 1.
     """
-    def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1):
+    def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False):
         super(BertTrainAccumulateStepsWithLossScaleCell, self).__init__(auto_prefix=False)
         self.network = network
         self.weights = optimizer.parameters
         self.optimizer = optimizer
         self.accumulation_steps = accumulation_steps
+        self.enable_global_norm = enable_global_norm
         self.one = Tensor(np.array([1]).astype(np.int32))
         self.zero = Tensor(np.array([0]).astype(np.int32))
         self.local_step = Parameter(initializer(0, [1], mstype.int32), name="local_step")
@@ -582,6 +588,9 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
         grads = self.grad_reducer(self.accu_grads)
         scaling = scaling_sens * self.degree * self.accumulation_steps
         grads = self.hyper_map(F.partial(grad_scale, scaling), grads)
-        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
+        if self.enable_global_norm:
+            grads = ClipByGlobalNorm()(grads)
+        else:
+            grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
         accu_overflow = self.overflow_reducer(accu_overflow)
         F.control_depend(grads, accu_overflow)
...
@@ -24,6 +24,7 @@ cfg = edict({
     'scale_factor': 2,
     'scale_window': 1000,
     'optimizer': 'Lamb',
+    'enable_global_norm': False,
     'AdamWeightDecay': edict({
         'learning_rate': 3e-5,
         'end_learning_rate': 0.0,
@@ -115,6 +116,5 @@ if cfg.bert_network == 'large':
         input_mask_from_dataset=True,
         token_type_ids_from_dataset=True,
         dtype=mstype.float32,
-        compute_type=mstype.float16,
-        enable_fused_layernorm=True
+        compute_type=mstype.float16
     )
@@ -23,12 +23,62 @@ import numpy as np
 import mindspore.nn as nn
 from mindspore import log as logger
 from mindspore.ops import operations as P
+from mindspore.ops import functional as F
+from mindspore.ops import composite as C
 from mindspore.common.tensor import Tensor
 from mindspore.common import dtype as mstype
 from mindspore.train.callback import Callback
 from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR


+get_square_sum = C.MultitypeFuncGraph("get_square_sum")
+
+
+@get_square_sum.register("Tensor")
+def _get_square_sum(grad):
+    norm = P.ReduceSum(False)(F.square(grad), ())
+    norm = F.expand_dims(F.cast(norm, mstype.float32), 0)
+    return norm
+
+
+apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")
+
+
+@apply_global_norm.register("Tensor", "Tensor", "Tensor")
+def _apply_global_norm(clip_norm, global_norm, grad):
+    grad = grad * clip_norm / global_norm
+    return grad
+
+
+class GlobalNorm(nn.Cell):
+    """
+    Calculate the global norm value of given tensors.
+    """
+    def __init__(self):
+        super(GlobalNorm, self).__init__()
+        self.norm = nn.Norm()
+        self.hyper_map = C.HyperMap()
+
+    def construct(self, grads):
+        square_sum = self.hyper_map(get_square_sum, grads)
+        global_norms = F.sqrt(F.addn(square_sum) / F.scalar_to_array(len(square_sum)))
+        return global_norms
+
+
+class ClipByGlobalNorm(nn.Cell):
+    """
+    Clip grads by global norm.
+    """
+    def __init__(self, clip_norm=1.0):
+        super(ClipByGlobalNorm, self).__init__()
+        self.global_norm = GlobalNorm()
+        self.clip_norm = Tensor([clip_norm], mstype.float32)
+        self.hyper_map = C.HyperMap()
+
+    def construct(self, grads):
+        global_norm = self.global_norm(grads)
+        cond = P.GreaterEqual()(global_norm, self.clip_norm)
+        global_norm = F.select(cond, global_norm, self.clip_norm)
+        grads = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), grads)
+        return grads
+
+
 class CrossEntropyCalculation(nn.Cell):
     """
     Cross Entropy loss
...
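For reference, a plain NumPy sketch of the arithmetic the new GlobalNorm/ClipByGlobalNorm cells perform; the function name below is hypothetical and not part of the commit, and it is illustrative only.

import numpy as np

def clip_by_global_norm(grads, clip_norm=1.0):
    # Per-tensor sum of squared elements, as in _get_square_sum.
    square_sums = [np.sum(np.square(g)) for g in grads]
    # Note: as written above, the summed squares are divided by the number of
    # gradient tensors before the square root is taken.
    global_norm = np.sqrt(np.sum(square_sums) / len(square_sums))
    # Scale only when the norm exceeds the threshold (the F.select branch above).
    global_norm = global_norm if global_norm >= clip_norm else clip_norm
    return [g * clip_norm / global_norm for g in grads]

# Example: two gradient tensors whose combined norm exceeds 1.0 get scaled down.
grads = [np.ones((2, 3), np.float32), np.full((4,), 0.5, np.float32)]
clipped = clip_by_global_norm(grads, clip_norm=1.0)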