diff --git a/example/bert_clue/dataset.py b/example/bert_clue/dataset.py index 9dbe7b8ce41e583b5e654b85163208c80ad43191..f930b67330d4ab6123a10a7c0ebe2c415d215737 100644 --- a/example/bert_clue/dataset.py +++ b/example/bert_clue/dataset.py @@ -52,7 +52,7 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e ds = ds.map(input_columns="input_ids", operations=type_cast_op) # apply batch operations ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True) - ds = ds.repeat(repeat_count) + ds = ds.repeat(new_repeat_count) logger.info("data size: {}".format(ds.get_dataset_size())) logger.info("repeatcount: {}".format(ds.get_repeat_count())) return ds, new_repeat_count diff --git a/example/bert_clue/run_pretrain.py b/example/bert_clue/run_pretrain.py index c587d41bc321decca5dd35245fbbb9c057c0b4bc..e46806c315a5fcd8c0cf1d3c62da55797ecd498d 100644 --- a/example/bert_clue/run_pretrain.py +++ b/example/bert_clue/run_pretrain.py @@ -81,6 +81,11 @@ def run_pretrain(): context.reset_auto_parallel_context() context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True, device_num=device_num) + from mindspore.parallel._auto_parallel_context import auto_parallel_context + if bert_net_cfg.num_hidden_layers == 12: + auto_parallel_context().set_all_reduce_fusion_split_indices([28, 55, 82, 109, 136, 163, 190, 205]) + elif bert_net_cfg.num_hidden_layers == 24: + auto_parallel_context().set_all_reduce_fusion_split_indices([38, 93, 148, 203, 258, 313, 368, 397]) D.init() rank = args_opt.device_id % device_num else: