未验证 提交 d6ee6bdb 编写于 作者: D Double_V 提交者: GitHub

Merge pull request #2347 from littletomatodonkey/dyg/fix_pre_rec

fix eval res vary for different times
......@@ -23,6 +23,7 @@ class SimpleDataSet(Dataset):
def __init__(self, config, mode, logger, seed=None):
super(SimpleDataSet, self).__init__()
self.logger = logger
self.mode = mode.lower()
global_config = config['Global']
dataset_config = config[mode]['dataset']
......@@ -45,7 +46,7 @@ class SimpleDataSet(Dataset):
logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
self.data_idx_order_list = list(range(len(self.data_lines)))
if mode.lower() == "train":
if self.mode == "train" and self.do_shuffle:
self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config)
......@@ -56,16 +57,16 @@ class SimpleDataSet(Dataset):
for idx, file in enumerate(file_list):
with open(file, "rb") as f:
lines = f.readlines()
random.seed(self.seed)
lines = random.sample(lines,
round(len(lines) * ratio_list[idx]))
if self.mode == "train" or ratio_list[idx] < 1.0:
random.seed(self.seed)
lines = random.sample(lines,
round(len(lines) * ratio_list[idx]))
data_lines.extend(lines)
return data_lines
def shuffle_data_random(self):
if self.do_shuffle:
random.seed(self.seed)
random.shuffle(self.data_lines)
random.seed(self.seed)
random.shuffle(self.data_lines)
return
def __getitem__(self, idx):
......@@ -90,7 +91,10 @@ class SimpleDataSet(Dataset):
data_line, e))
outs = None
if outs is None:
return self.__getitem__(np.random.randint(self.__len__()))
# during evaluation, we should fix the idx to get same results for many times of evaluation.
rnd_idx = np.random.randint(self.__len__(
)) if self.mode == "train" else (idx + 1) % self.__len__()
return self.__getitem__(rnd_idx)
return outs
def __len__(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册