提交 d57960ed · 作者: shibeiji (commit d57960ed, authored by shibeiji)

delete the redundant argument while initializing class of GradOperation

上级 7371cedd
@@ -121,9 +121,10 @@ def run_pretrain():
     new_repeat_count = args_opt.epoch_size * ds.get_dataset_size() // args_opt.data_sink_steps
     if args_opt.train_steps > 0:
-        new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
+        train_steps = args_opt.train_steps * args_opt.accumulation_steps
+        new_repeat_count = min(new_repeat_count, train_steps // args_opt.data_sink_steps)
     else:
-        args_opt.train_steps = args_opt.epoch_size * ds.get_dataset_size()
+        args_opt.train_steps = args_opt.epoch_size * ds.get_dataset_size() // args_opt.accumulation_steps
     logger.info("train steps: {}".format(args_opt.train_steps))
     if cfg.optimizer == 'Lamb':
......
@@ -487,9 +487,7 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
         self.accu_overflow = Parameter(initializer(0, [1], mstype.int32), name="accu_overflow")
         self.loss = Parameter(initializer(0, [1], mstype.float32), name="accu_loss")
-        self.grad = C.GradOperation('grad',
-                                    get_by_list=True,
-                                    sens_param=True)
+        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
         self.reducer_flag = False
         self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
         if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册