提交 06b3d482 编写于 作者: J jonyguo

1. add set_dataset_size for MindDataset 2. modify parameter dupe_factor from 5 to 10

上级 5845c4ad
......@@ -83,7 +83,7 @@ for index in $(seq 0 $file_list_len); do
--max_predictions_per_seq=20 \
--masked_lm_prob=0.15 \
--random_seed=12345 \
--dupe_factor=5 >/tmp/${output_filename[$index]}.log 2>&1 &
--dupe_factor=10 >/tmp/${output_filename[$index]}.log 2>&1 & # user defined
process_count=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
echo "Total task: ${#file_list[*]}, processing: ${process_count}"
if [ $process_count -ge $avaiable_core_size ]; then
......
......@@ -44,4 +44,4 @@ python ../../../third_party/to_mindrecord/zhwiki/create_pretraining_data_patched
--max_predictions_per_seq=20 \
--masked_lm_prob=0.15 \
--random_seed=12345 \
--dupe_factor=5
--dupe_factor=10 # user defined
......@@ -2671,20 +2671,30 @@ class MindDataset(SourceDataset):
Return:
Number, number of batches.
"""
if self.load_dataset:
dataset_file = [self.dataset_file]
else:
dataset_file = self.dataset_file
num_rows = MindRecordOp.get_num_rows(dataset_file, self.load_dataset, self.sampler, self.num_padded)
if self.partitions is not None and self.partitions[0] > 0:
if num_rows % self.partitions[0] == 0:
num_rows = num_rows // self.partitions[0]
if self._dataset_size is None:
if self.load_dataset:
dataset_file = [self.dataset_file]
else:
if self.num_padded > 0:
raise RuntimeError(
"Dataset size plus number of padded samples is not divisible by number of shards.")
num_rows = num_rows // self.partitions[0] + 1
return num_rows
dataset_file = self.dataset_file
num_rows = MindRecordOp.get_num_rows(dataset_file, self.load_dataset, self.sampler, self.num_padded)
if self.partitions is not None and self.partitions[0] > 0:
if num_rows % self.partitions[0] == 0:
num_rows = num_rows // self.partitions[0]
else:
if self.num_padded > 0:
raise RuntimeError(
"Dataset size plus number of padded samples is not divisible by number of shards.")
num_rows = num_rows // self.partitions[0] + 1
return num_rows
return self._dataset_size
# manually set dataset_size as a tempoary solution.
def set_dataset_size(self, value):
logger.warning("WARN_DEPRECATED: This method is deprecated. Please use get_dataset_size directly.")
if value >= 0:
self._dataset_size = value
else:
raise ValueError('set dataset_size with negative value {}'.format(value))
def is_shuffled(self):
if self.shuffle_option is None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册