diff --git a/model_zoo/official/nlp/bert/run_pretrain.py b/model_zoo/official/nlp/bert/run_pretrain.py
index 4b80b472fa46782f1dbe7f9f807c274a5e95dc01..0d85aff4c820862617e97bfc0462b974954d006c 100644
--- a/model_zoo/official/nlp/bert/run_pretrain.py
+++ b/model_zoo/official/nlp/bert/run_pretrain.py
@@ -121,9 +121,10 @@ def run_pretrain():
 
     new_repeat_count = args_opt.epoch_size * ds.get_dataset_size() // args_opt.data_sink_steps
     if args_opt.train_steps > 0:
-        new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
+        train_steps = args_opt.train_steps * args_opt.accumulation_steps
+        new_repeat_count = min(new_repeat_count, train_steps // args_opt.data_sink_steps)
     else:
-        args_opt.train_steps = args_opt.epoch_size * ds.get_dataset_size()
+        args_opt.train_steps = args_opt.epoch_size * ds.get_dataset_size() // args_opt.accumulation_steps
     logger.info("train steps: {}".format(args_opt.train_steps))
 
     if cfg.optimizer == 'Lamb':
diff --git a/model_zoo/official/nlp/bert/src/bert_for_pre_training.py b/model_zoo/official/nlp/bert/src/bert_for_pre_training.py
index b57f93143a87203c3683af3048e05a249469ca35..7426f73add45794bdd3ecd5ec035869e9b094743 100644
--- a/model_zoo/official/nlp/bert/src/bert_for_pre_training.py
+++ b/model_zoo/official/nlp/bert/src/bert_for_pre_training.py
@@ -487,9 +487,7 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
         self.accu_overflow = Parameter(initializer(0, [1], mstype.int32), name="accu_overflow")
         self.loss = Parameter(initializer(0, [1], mstype.float32), name="accu_loss")
 
-        self.grad = C.GradOperation('grad',
-                                    get_by_list=True,
-                                    sens_param=True)
+        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
         self.reducer_flag = False
         self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
         if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
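
Note (not part of the diff): the run_pretrain.py hunk changes the unit of `train_steps` from mini-batches to optimizer updates when gradient accumulation is enabled, since each optimizer update now consumes `accumulation_steps` mini-batches; the bert_for_pre_training.py hunk appears to only track the newer MindSpore `C.GradOperation` signature, which dropped the positional name argument, without changing behavior. Below is a minimal, hedged sketch of the step-count arithmetic; `compute_repeat_count` is a hypothetical helper that is not in the repo, and its argument names simply mirror the script's `args_opt` fields.

```python
# Hypothetical helper illustrating the arithmetic in the run_pretrain.py hunk.
# With gradient accumulation, one optimizer update consumes `accumulation_steps`
# mini-batches, so the sink/repeat bookkeeping operates on mini-batch counts
# while `train_steps` counts optimizer updates.

def compute_repeat_count(epoch_size, dataset_size, data_sink_steps,
                         train_steps, accumulation_steps):
    """Return (new_repeat_count, effective_train_steps) as in run_pretrain()."""
    new_repeat_count = epoch_size * dataset_size // data_sink_steps
    if train_steps > 0:
        # train_steps counts optimizer updates; convert to mini-batch steps
        # before comparing against the data-sink granularity.
        mini_batch_steps = train_steps * accumulation_steps
        new_repeat_count = min(new_repeat_count, mini_batch_steps // data_sink_steps)
    else:
        # One epoch yields dataset_size mini-batches, i.e.
        # dataset_size // accumulation_steps optimizer updates.
        train_steps = epoch_size * dataset_size // accumulation_steps
    return new_repeat_count, train_steps

# Example: 2 epochs, 1000 batches/epoch, sink every 100 batches, 4-step
# accumulation, train_steps capped at 300 optimizer updates:
# mini-batch budget = 300 * 4 = 1200 -> repeat count = min(20, 12) = 12.
print(compute_repeat_count(2, 1000, 100, 300, 4))  # (12, 300)
```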