提交 ff500c67 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2122 add set_dataset_size for MindDataset

Merge pull request !2122 from guozhijian/add_set_dataset_size_for_minddataset
...@@ -83,7 +83,7 @@ for index in $(seq 0 $file_list_len); do ...@@ -83,7 +83,7 @@ for index in $(seq 0 $file_list_len); do
--max_predictions_per_seq=20 \ --max_predictions_per_seq=20 \
--masked_lm_prob=0.15 \ --masked_lm_prob=0.15 \
--random_seed=12345 \ --random_seed=12345 \
--dupe_factor=5 >/tmp/${output_filename[$index]}.log 2>&1 & --dupe_factor=10 >/tmp/${output_filename[$index]}.log 2>&1 & # user defined
process_count=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l` process_count=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
echo "Total task: ${#file_list[*]}, processing: ${process_count}" echo "Total task: ${#file_list[*]}, processing: ${process_count}"
if [ $process_count -ge $avaiable_core_size ]; then if [ $process_count -ge $avaiable_core_size ]; then
......
...@@ -44,4 +44,4 @@ python ../../../third_party/to_mindrecord/zhwiki/create_pretraining_data_patched ...@@ -44,4 +44,4 @@ python ../../../third_party/to_mindrecord/zhwiki/create_pretraining_data_patched
--max_predictions_per_seq=20 \ --max_predictions_per_seq=20 \
--masked_lm_prob=0.15 \ --masked_lm_prob=0.15 \
--random_seed=12345 \ --random_seed=12345 \
--dupe_factor=5 --dupe_factor=10 # user defined
...@@ -2577,20 +2577,30 @@ class MindDataset(SourceDataset): ...@@ -2577,20 +2577,30 @@ class MindDataset(SourceDataset):
Return: Return:
Number, number of batches. Number, number of batches.
""" """
if self.load_dataset: if self._dataset_size is None:
dataset_file = [self.dataset_file] if self.load_dataset:
else: dataset_file = [self.dataset_file]
dataset_file = self.dataset_file
num_rows = MindRecordOp.get_num_rows(dataset_file, self.load_dataset, self.sampler, self.num_padded)
if self.partitions is not None and self.partitions[0] > 0:
if num_rows % self.partitions[0] == 0:
num_rows = num_rows // self.partitions[0]
else: else:
if self.num_padded > 0: dataset_file = self.dataset_file
raise RuntimeError( num_rows = MindRecordOp.get_num_rows(dataset_file, self.load_dataset, self.sampler, self.num_padded)
"Dataset size plus number of padded samples is not divisible by number of shards.") if self.partitions is not None and self.partitions[0] > 0:
num_rows = num_rows // self.partitions[0] + 1 if num_rows % self.partitions[0] == 0:
return num_rows num_rows = num_rows // self.partitions[0]
else:
if self.num_padded > 0:
raise RuntimeError(
"Dataset size plus number of padded samples is not divisible by number of shards.")
num_rows = num_rows // self.partitions[0] + 1
return num_rows
return self._dataset_size
# manually set dataset_size as a tempoary solution.
def set_dataset_size(self, value):
logger.warning("WARN_DEPRECATED: This method is deprecated. Please use get_dataset_size directly.")
if value >= 0:
self._dataset_size = value
else:
raise ValueError('set dataset_size with negative value {}'.format(value))
def is_shuffled(self): def is_shuffled(self):
if self.shuffle_option is None: if self.shuffle_option is None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册