Commit ac958362 authored by chenhaozhe

add global norm in bert

Parent 7371cedd
......@@ -178,12 +178,14 @@ def run_pretrain():
if args_opt.accumulation_steps <= 1:
net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
-                                            scale_update_cell=update_cell)
+                                            scale_update_cell=update_cell,
+                                            enable_global_norm=cfg.enable_global_norm)
else:
accumulation_steps = args_opt.accumulation_steps
net_with_grads = BertTrainAccumulateStepsWithLossScaleCell(net_with_loss, optimizer=optimizer,
scale_update_cell=update_cell,
-                                            accumulation_steps=accumulation_steps)
+                                            accumulation_steps=accumulation_steps,
+                                            enable_global_norm=cfg.enable_global_norm)
else:
net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)
......
......@@ -29,6 +29,7 @@ from mindspore.communication.management import get_group_size
from mindspore import context
from mindspore.ops import _selected_ops
from .bert_model import BertModel
from .utils import ClipByGlobalNorm
GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0
......@@ -348,11 +349,12 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
optimizer (Optimizer): Optimizer for updating the weights.
scale_update_cell (Cell): Cell to do the loss scale. Default: None.
"""
- def __init__(self, network, optimizer, scale_update_cell=None):
+ def __init__(self, network, optimizer, scale_update_cell=None, enable_global_norm=False):
super(BertTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
self.network = network
self.weights = optimizer.parameters
self.optimizer = optimizer
self.enable_global_norm = enable_global_norm
self.grad = C.GradOperation(get_by_list=True,
sens_param=True)
self.reducer_flag = False
......@@ -419,6 +421,9 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
# apply grad reducer on grads
grads = self.grad_reducer(grads)
grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
if self.enable_global_norm:
grads = ClipByGlobalNorm()(grads)
else:
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
self.get_status(init)
flag_sum = self.reduce_sum(init, (0,))
......@@ -474,12 +479,13 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
batch_size * accumulation_steps. Default: 1.
"""
- def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1):
+ def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False):
super(BertTrainAccumulateStepsWithLossScaleCell, self).__init__(auto_prefix=False)
self.network = network
self.weights = optimizer.parameters
self.optimizer = optimizer
self.accumulation_steps = accumulation_steps
self.enable_global_norm = enable_global_norm
self.one = Tensor(np.array([1]).astype(np.int32))
self.zero = Tensor(np.array([0]).astype(np.int32))
self.local_step = Parameter(initializer(0, [1], mstype.int32), name="local_step")
......@@ -582,6 +588,9 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
grads = self.grad_reducer(self.accu_grads)
scaling = scaling_sens * self.degree * self.accumulation_steps
grads = self.hyper_map(F.partial(grad_scale, scaling), grads)
if self.enable_global_norm:
grads = ClipByGlobalNorm()(grads)
else:
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
accu_overflow = self.overflow_reducer(accu_overflow)
F.control_depend(grads, accu_overflow)
......
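For reference, the `ClipByGlobalNorm` cell called in the two branches above is added in src/utils.py further down in this diff. The following is a rough NumPy sketch (not part of the commit) of what it computes, using plain arrays instead of MindSpore tensors; note that, as written in the diff, the summed squared norms are divided by the number of gradient tensors before the square root is taken:

```python
import numpy as np

def clip_by_global_norm_sketch(grads, clip_norm=1.0):
    """NumPy approximation of GlobalNorm + ClipByGlobalNorm from src/utils.py."""
    # Sum of squared elements over all gradient tensors, averaged over the
    # number of tensors (mirrors the division by len(square_sum) in GlobalNorm).
    square_sum = sum(float(np.sum(np.square(g))) for g in grads)
    global_norm = np.sqrt(square_sum / len(grads))
    # When the norm is already below clip_norm, F.select keeps clip_norm,
    # so the scaling factor becomes 1 and the gradients pass through unchanged.
    denom = global_norm if global_norm >= clip_norm else clip_norm
    return [g * clip_norm / denom for g in grads]

grads = [np.array([3.0, 4.0]), np.array([0.5])]
print(clip_by_global_norm_sketch(grads))  # jointly rescaled by roughly 1 / 3.55
```

With `enable_global_norm` left at its default of False, the existing per-tensor `clip_grad` path is used unchanged.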
......@@ -24,6 +24,7 @@ cfg = edict({
'scale_factor': 2,
'scale_window': 1000,
'optimizer': 'Lamb',
'enable_global_norm': False,
'AdamWeightDecay': edict({
'learning_rate': 3e-5,
'end_learning_rate': 0.0,
......@@ -115,6 +116,5 @@ if cfg.bert_network == 'large':
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
- compute_type=mstype.float16,
- enable_fused_layernorm=True
+ compute_type=mstype.float16
)
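To turn the new clipping on, the flag added above would be set to True in src/config.py. A minimal sketch of the relevant slice of `cfg`; the EasyDict import is assumed to match the existing import in that file, and all other fields stay as shown in the diff:

```python
from easydict import EasyDict as edict  # assumed to match the existing import in src/config.py

cfg = edict({
    # ... other training fields unchanged ...
    'optimizer': 'Lamb',
    'enable_global_norm': True,  # this commit's default config uses False
})
```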
......@@ -23,12 +23,62 @@ import numpy as np
import mindspore.nn as nn
from mindspore import log as logger
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype
from mindspore.train.callback import Callback
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR
get_square_sum = C.MultitypeFuncGraph("get_square_sum")
@get_square_sum.register("Tensor")
def _get_square_sum(grad):
norm = P.ReduceSum(False)(F.square(grad), ())
norm = F.expand_dims(F.cast(norm, mstype.float32), 0)
return norm
apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")
@apply_global_norm.register("Tensor", "Tensor", "Tensor")
def _apply_global_norm(clip_norm, global_norm, grad):
grad = grad * clip_norm / global_norm
return grad
class GlobalNorm(nn.Cell):
"""
Calculate the global norm value of given tensors
"""
def __init__(self):
super(GlobalNorm, self).__init__()
self.norm = nn.Norm()
self.hyper_map = C.HyperMap()
def construct(self, grads):
square_sum = self.hyper_map(get_square_sum, grads)
global_norms = F.sqrt(F.addn(square_sum) / F.scalar_to_array(len(square_sum)))
return global_norms
class ClipByGlobalNorm(nn.Cell):
"""
Clip grads by global norm
"""
def __init__(self, clip_norm=1.0):
super(ClipByGlobalNorm, self).__init__()
self.global_norm = GlobalNorm()
self.clip_norm = Tensor([clip_norm], mstype.float32)
self.hyper_map = C.HyperMap()
def construct(self, grads):
global_norm = self.global_norm(grads)
cond = P.GreaterEqual()(global_norm, self.clip_norm)
global_norm = F.select(cond, global_norm, self.clip_norm)
grads = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), grads)
return grads
class CrossEntropyCalculation(nn.Cell):
"""
Cross Entropy loss
......
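Finally, a minimal usage sketch (not part of the commit) showing the new cell applied to a tuple of gradient tensors, the same way the train cells call it. The import path `src.utils` is assumed from this repository layout, and exact behaviour may vary across MindSpore versions:

```python
import numpy as np
from mindspore import Tensor, context
from mindspore.common import dtype as mstype
from src.utils import ClipByGlobalNorm  # path assumed from this repository layout

# Run eagerly so the cell can be called directly on a tuple of tensors.
context.set_context(mode=context.PYNATIVE_MODE)

grads = (Tensor(np.array([3.0, 4.0]), mstype.float32),
         Tensor(np.array([0.5]), mstype.float32))

# Each tensor is rescaled by clip_norm / global_norm when the (averaged)
# global norm exceeds clip_norm; otherwise the gradients pass through unchanged.
clipped = ClipByGlobalNorm(clip_norm=1.0)(grads)
print([c.asnumpy() for c in clipped])
```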