From 113d8a8eb5a6261bc150f208926953ddff8e0add Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Fri, 26 Mar 2021 05:41:36 +0000 Subject: [PATCH] fix eval res vary for different times --- ppocr/data/simple_dataset.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/ppocr/data/simple_dataset.py b/ppocr/data/simple_dataset.py index d2a86b0f..ea57b785 100644 --- a/ppocr/data/simple_dataset.py +++ b/ppocr/data/simple_dataset.py @@ -23,6 +23,7 @@ class SimpleDataSet(Dataset): def __init__(self, config, mode, logger, seed=None): super(SimpleDataSet, self).__init__() self.logger = logger + self.mode = mode.lower() global_config = config['Global'] dataset_config = config[mode]['dataset'] @@ -45,7 +46,7 @@ class SimpleDataSet(Dataset): logger.info("Initialize indexs of datasets:%s" % label_file_list) self.data_lines = self.get_image_info_list(label_file_list, ratio_list) self.data_idx_order_list = list(range(len(self.data_lines))) - if mode.lower() == "train": + if self.mode == "train" and self.do_shuffle: self.shuffle_data_random() self.ops = create_operators(dataset_config['transforms'], global_config) @@ -56,16 +57,16 @@ class SimpleDataSet(Dataset): for idx, file in enumerate(file_list): with open(file, "rb") as f: lines = f.readlines() - random.seed(self.seed) - lines = random.sample(lines, - round(len(lines) * ratio_list[idx])) + if self.mode == "train" or ratio_list[idx] < 1.0: + random.seed(self.seed) + lines = random.sample(lines, + round(len(lines) * ratio_list[idx])) data_lines.extend(lines) return data_lines def shuffle_data_random(self): - if self.do_shuffle: - random.seed(self.seed) - random.shuffle(self.data_lines) + random.seed(self.seed) + random.shuffle(self.data_lines) return def __getitem__(self, idx): @@ -90,7 +91,10 @@ class SimpleDataSet(Dataset): data_line, e)) outs = None if outs is None: - return self.__getitem__(np.random.randint(self.__len__())) + # during evaluation, we should fix the idx to get same results for many times of evaluation. + rnd_idx = np.random.randint(self.__len__( + )) if self.mode == "train" else (idx + 1) % self.__len__() + return self.__getitem__(rnd_idx) return outs def __len__(self): -- GitLab