diff --git a/model_zoo/mass/src/dataset/load_dataset.py b/model_zoo/mass/src/dataset/load_dataset.py index 53ad5c74911643224aac65c7112e913490186995..be599413745007da5681b28ea208e428a4be61a0 100644 --- a/model_zoo/mass/src/dataset/load_dataset.py +++ b/model_zoo/mass/src/dataset/load_dataset.py @@ -58,9 +58,6 @@ def _load_dataset(input_files, batch_size, epoch_count=1, ori_dataset_size = ds.get_dataset_size() print(f" | Dataset size: {ori_dataset_size}.") repeat_count = epoch_count - if sink_mode: - ds.set_dataset_size(sink_step * batch_size) - repeat_count = epoch_count * ori_dataset_size // ds.get_dataset_size() type_cast_op = deC.TypeCast(mstype.int32) ds = ds.map(input_columns="src", operations=type_cast_op) diff --git a/model_zoo/mass/train.py b/model_zoo/mass/train.py index a0c1959265956adacbfdd9e0e12305ba4dbf070c..07e4469bd5308a875d9bfc8157a3f9fe39c47825 100644 --- a/model_zoo/mass/train.py +++ b/model_zoo/mass/train.py @@ -79,11 +79,15 @@ def _train(model, config: TransformerConfig, if pre_training_dataset is not None: print(" | Start pre-training job.") - epoch_size = pre_training_dataset.get_repeat_count() + epoch_size = config.epochs * pre_training_dataset.get_dataset_size() // config.dataset_sink_step + if os.getenv("RANK_SIZE") is not None and int(os.getenv("RANK_SIZE")) > 1: print(f" | Rank {MultiAscend.get_rank()} Call model train.") + model.train(epoch_size, pre_training_dataset, - callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode) + callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode, + sink_size=config.dataset_sink_step) + # Test the accuracy of the model. if test_dataset is not None: print(" | Start test job.") @@ -93,10 +97,11 @@ def _train(model, config: TransformerConfig, if fine_tune_dataset is not None: print(" | Start fine-tuning job.") - epoch_size = fine_tune_dataset.get_repeat_count() + epoch_size = config.epochs * fine_tune_dataset.get_dataset_size() // config.dataset_sink_step model.train(epoch_size, fine_tune_dataset, - callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode) + callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode, + sink_size=config.dataset_sink_step) # Test the accuracy of the model. if test_dataset is not None: