From ac95836257d5870e82c031388fa1988e5f75d5fe Mon Sep 17 00:00:00 2001
From: chenhaozhe
Date: Tue, 14 Jul 2020 15:08:46 +0800
Subject: [PATCH] add global norm in bert

---
 model_zoo/official/nlp/bert/run_pretrain.py   |  6 ++-
 .../nlp/bert/src/bert_for_pre_training.py     | 17 +++++--
 model_zoo/official/nlp/bert/src/config.py     |  4 +-
 model_zoo/official/nlp/bert/src/utils.py      | 50 +++++++++++++++++++
 4 files changed, 69 insertions(+), 8 deletions(-)

diff --git a/model_zoo/official/nlp/bert/run_pretrain.py b/model_zoo/official/nlp/bert/run_pretrain.py
index 4b80b472f..bb7f9c271 100644
--- a/model_zoo/official/nlp/bert/run_pretrain.py
+++ b/model_zoo/official/nlp/bert/run_pretrain.py
@@ -178,12 +178,14 @@ def run_pretrain():
         if args_opt.accumulation_steps <= 1:
             net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
-                                                               scale_update_cell=update_cell)
+                                                               scale_update_cell=update_cell,
+                                                               enable_global_norm=cfg.enable_global_norm)
         else:
             accumulation_steps = args_opt.accumulation_steps
             net_with_grads = BertTrainAccumulateStepsWithLossScaleCell(net_with_loss, optimizer=optimizer,
                                                                        scale_update_cell=update_cell,
-                                                                       accumulation_steps=accumulation_steps)
+                                                                       accumulation_steps=accumulation_steps,
+                                                                       enable_global_norm=cfg.enable_global_norm)
     else:
         net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)
diff --git a/model_zoo/official/nlp/bert/src/bert_for_pre_training.py b/model_zoo/official/nlp/bert/src/bert_for_pre_training.py
index b57f93143..658e2770b 100644
--- a/model_zoo/official/nlp/bert/src/bert_for_pre_training.py
+++ b/model_zoo/official/nlp/bert/src/bert_for_pre_training.py
@@ -29,6 +29,7 @@ from mindspore.communication.management import get_group_size
 from mindspore import context
 from mindspore.ops import _selected_ops
 from .bert_model import BertModel
+from .utils import ClipByGlobalNorm
 
 GRADIENT_CLIP_TYPE = 1
 GRADIENT_CLIP_VALUE = 1.0
@@ -348,11 +349,12 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
         optimizer (Optimizer): Optimizer for updating the weights.
         scale_update_cell (Cell): Cell to do the loss scale. Default: None.
     """
-    def __init__(self, network, optimizer, scale_update_cell=None):
+    def __init__(self, network, optimizer, scale_update_cell=None, enable_global_norm=False):
         super(BertTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
         self.network = network
         self.weights = optimizer.parameters
         self.optimizer = optimizer
+        self.enable_global_norm = enable_global_norm
         self.grad = C.GradOperation(get_by_list=True,
                                     sens_param=True)
         self.reducer_flag = False
@@ -419,7 +421,10 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
         # apply grad reducer on grads
         grads = self.grad_reducer(grads)
         grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
-        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
+        if self.enable_global_norm:
+            grads = ClipByGlobalNorm()(grads)
+        else:
+            grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
         self.get_status(init)
         flag_sum = self.reduce_sum(init, (0,))
         if self.is_distributed:
@@ -474,12 +479,13 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
         accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
                                   batch_size * accumulation_steps. Default: 1.
""" - def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1): + def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False): super(BertTrainAccumulateStepsWithLossScaleCell, self).__init__(auto_prefix=False) self.network = network self.weights = optimizer.parameters self.optimizer = optimizer self.accumulation_steps = accumulation_steps + self.enable_global_norm = enable_global_norm self.one = Tensor(np.array([1]).astype(np.int32)) self.zero = Tensor(np.array([0]).astype(np.int32)) self.local_step = Parameter(initializer(0, [1], mstype.int32), name="local_step") @@ -582,7 +588,10 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell): grads = self.grad_reducer(self.accu_grads) scaling = scaling_sens * self.degree * self.accumulation_steps grads = self.hyper_map(F.partial(grad_scale, scaling), grads) - grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) + if self.enable_global_norm: + grads = ClipByGlobalNorm()(grad) + else: + grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) accu_overflow = self.overflow_reducer(accu_overflow) F.control_depend(grads, accu_overflow) overflow = self.less_equal(self.base, accu_overflow) diff --git a/model_zoo/official/nlp/bert/src/config.py b/model_zoo/official/nlp/bert/src/config.py index d0da37f60..32c8bb86a 100644 --- a/model_zoo/official/nlp/bert/src/config.py +++ b/model_zoo/official/nlp/bert/src/config.py @@ -24,6 +24,7 @@ cfg = edict({ 'scale_factor': 2, 'scale_window': 1000, 'optimizer': 'Lamb', + 'enable_global_norm': False, 'AdamWeightDecay': edict({ 'learning_rate': 3e-5, 'end_learning_rate': 0.0, @@ -115,6 +116,5 @@ if cfg.bert_network == 'large': input_mask_from_dataset=True, token_type_ids_from_dataset=True, dtype=mstype.float32, - compute_type=mstype.float16, - enable_fused_layernorm=True + compute_type=mstype.float16 ) diff --git a/model_zoo/official/nlp/bert/src/utils.py b/model_zoo/official/nlp/bert/src/utils.py index 46d8591e2..56422a07d 100644 --- a/model_zoo/official/nlp/bert/src/utils.py +++ b/model_zoo/official/nlp/bert/src/utils.py @@ -23,12 +23,62 @@ import numpy as np import mindspore.nn as nn from mindspore import log as logger from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.ops import composite as C from mindspore.common.tensor import Tensor from mindspore.common import dtype as mstype from mindspore.train.callback import Callback from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR +get_square_sum = C.MultitypeFuncGraph("get_square_sum") +@get_square_sum.register("Tensor") +def _get_square_sum(grad): + norm = P.ReduceSum(False)(F.square(grad), ()) + norm = F.expand_dims(F.cast(norm, mstype.float32), 0) + return norm + + +apply_global_norm = C.MultitypeFuncGraph("apply_global_norm") +@apply_global_norm.register("Tensor", "Tensor", "Tensor") +def _apply_global_norm(clip_norm, global_norm, grad): + grad = grad * clip_norm / global_norm + return grad + + +class GlobalNorm(nn.Cell): + """ + Calculate the global norm value of given tensors + """ + def __init__(self): + super(GlobalNorm, self).__init__() + self.norm = nn.Norm() + self.hyper_map = C.HyperMap() + + def construct(self, grads): + square_sum = self.hyper_map(get_square_sum, grads) + global_norms = F.sqrt(F.addn(square_sum) / F.scalar_to_array(len(square_sum))) + return global_norms + + +class 
+class ClipByGlobalNorm(nn.Cell):
+    """
+    Clip grads by global norm
+    """
+    def __init__(self, clip_norm=1.0):
+        super(ClipByGlobalNorm, self).__init__()
+        self.global_norm = GlobalNorm()
+        self.clip_norm = Tensor([clip_norm], mstype.float32)
+        self.hyper_map = C.HyperMap()
+
+    def construct(self, grads):
+        global_norm = self.global_norm(grads)
+        cond = P.GreaterEqual()(global_norm, self.clip_norm)
+        global_norm = F.select(cond, global_norm, self.clip_norm)
+        grads = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), grads)
+        return grads
+
+
 class CrossEntropyCalculation(nn.Cell):
     """
     Cross Entropy loss
--
GitLab
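
Note (illustration only, not part of the diff): the clipping rule added in src/utils.py can be sanity-checked
against the plain NumPy sketch below. The helper name reference_clip_by_global_norm and the sample gradients
are invented for this example; it is only assumed to mirror the Cells in the patch. Two details worth keeping
in mind: GlobalNorm divides the summed squares by the number of gradient tensors before the square root, so
the value compared against clip_norm is the textbook global norm scaled by 1/sqrt(num_tensors), and the
rescaling only takes effect once that value reaches clip_norm (the F.select leaves smaller gradients untouched).

# Hypothetical NumPy reference for the ClipByGlobalNorm Cell in this patch (illustration only).
import numpy as np

def reference_clip_by_global_norm(grads, clip_norm=1.0):
    # Same quantity as GlobalNorm.construct: sqrt of the mean per-tensor squared sum.
    square_sums = [float(np.sum(np.square(g))) for g in grads]
    global_norm = np.sqrt(np.sum(square_sums) / len(square_sums))
    # Same rule as ClipByGlobalNorm.construct: rescale only when the norm reaches clip_norm.
    if global_norm >= clip_norm:
        return [g * clip_norm / global_norm for g in grads]
    return list(grads)

# Example: two fake gradient tensors whose norm exceeds 1.0 are scaled down by a common factor.
fake_grads = [np.full((2, 2), 3.0), np.full((3,), -4.0)]
clipped = reference_clip_by_global_norm(fake_grads, clip_norm=1.0)

Setting 'enable_global_norm': True in src/config.py makes run_pretrain.py pass the flag into both
BertTrainOneStepWithLossScaleCell and BertTrainAccumulateStepsWithLossScaleCell, so this path replaces
the per-tensor clip_grad call.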