From c4fcd14354146dc7f97efd2d484b7ac4d15b7506 Mon Sep 17 00:00:00 2001
From: littletomatodonkey <2120160898@bit.edu.cn>
Date: Thu, 26 Nov 2020 21:32:33 +0800
Subject: [PATCH] refine dynamic sampling (#1256)

---
Reviewer note: this copy differs from the upstream commit in one added line.
The scalar ratio_list expansion now multiplies by data_source_num directly;
data_source_num is already an int (len(label_file_list)), so wrapping it in
len() raised TypeError whenever ratio_list was given as a single number.
The post-image blob id on the index line below is still the upstream one.

 ppocr/data/simple_dataset.py | 66 ++++++++++--------------------------
 1 file changed, 17 insertions(+), 49 deletions(-)

diff --git a/ppocr/data/simple_dataset.py b/ppocr/data/simple_dataset.py
index 817b8fdb..097da768 100644
--- a/ppocr/data/simple_dataset.py
+++ b/ppocr/data/simple_dataset.py
@@ -32,12 +32,10 @@ class SimpleDataSet(Dataset):
         self.delimiter = dataset_config.get('delimiter', '\t')
         label_file_list = dataset_config.pop('label_file_list')
         data_source_num = len(label_file_list)
-        if data_source_num == 1:
-            ratio_list = [1.0]
-        else:
-            ratio_list = dataset_config.pop('ratio_list')
+        ratio_list = dataset_config.get("ratio_list", [1.0])
+        if isinstance(ratio_list, (float, int)):
+            ratio_list = [float(ratio_list)] * data_source_num
 
-        assert sum(ratio_list) == 1, "The sum of the ratio_list should be 1."
         assert len(
             ratio_list
         ) == data_source_num, "The length of ratio_list should be the same as the file_list."
@@ -45,62 +43,32 @@ class SimpleDataSet(Dataset):
         self.do_shuffle = loader_config['shuffle']
 
         logger.info("Initialize indexs of datasets:%s" % label_file_list)
-        self.data_lines_list, data_num_list = self.get_image_info_list(
-            label_file_list)
-        self.data_idx_order_list = self.dataset_traversal(
-            data_num_list, ratio_list, batch_size)
-        self.shuffle_data_random()
-
+        self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
+        self.data_idx_order_list = list(range(len(self.data_lines)))
+        if mode.lower() == "train":
+            self.shuffle_data_random()
         self.ops = create_operators(dataset_config['transforms'], global_config)
 
-    def get_image_info_list(self, file_list):
+    def get_image_info_list(self, file_list, ratio_list):
         if isinstance(file_list, str):
             file_list = [file_list]
-        data_lines_list = []
-        data_num_list = []
-        for file in file_list:
+        data_lines = []
+        for idx, file in enumerate(file_list):
             with open(file, "rb") as f:
                 lines = f.readlines()
-                data_lines_list.append(lines)
-                data_num_list.append(len(lines))
-        return data_lines_list, data_num_list
-
-    def dataset_traversal(self, data_num_list, ratio_list, batch_size):
-        select_num_list = []
-        dataset_num = len(data_num_list)
-        for dno in range(dataset_num):
-            select_num = round(batch_size * ratio_list[dno])
-            select_num = max(select_num, 1)
-            select_num_list.append(select_num)
-        data_idx_order_list = []
-        cur_index_sets = [0] * dataset_num
-        while True:
-            finish_read_num = 0
-            for dataset_idx in range(dataset_num):
-                cur_index = cur_index_sets[dataset_idx]
-                if cur_index >= data_num_list[dataset_idx]:
-                    finish_read_num += 1
-                else:
-                    select_num = select_num_list[dataset_idx]
-                    for sno in range(select_num):
-                        cur_index = cur_index_sets[dataset_idx]
-                        if cur_index >= data_num_list[dataset_idx]:
-                            break
-                        data_idx_order_list.append((dataset_idx, cur_index))
-                        cur_index_sets[dataset_idx] += 1
-            if finish_read_num == dataset_num:
-                break
-        return data_idx_order_list
+                lines = random.sample(lines,
+                                      round(len(lines) * ratio_list[idx]))
+                data_lines.extend(lines)
+        return data_lines
 
     def shuffle_data_random(self):
         if self.do_shuffle:
-            for dno in range(len(self.data_lines_list)):
-                random.shuffle(self.data_lines_list[dno])
+            random.shuffle(self.data_lines)
         return
 
     def __getitem__(self, idx):
-        dataset_idx, file_idx = self.data_idx_order_list[idx]
-        data_line = self.data_lines_list[dataset_idx][file_idx]
+        file_idx = self.data_idx_order_list[idx]
+        data_line = self.data_lines[file_idx]
         try:
             data_line = data_line.decode('utf-8')
             substr = data_line.strip("\n").split(self.delimiter)
-- 
GitLab