未验证 提交 a27a43ec 编写于 作者: Z zhoujun 提交者: GitHub

Merge pull request #1807 from littletomatodonkey/dyg/fix_seed

fix data replication for multi-cards sampling
......@@ -51,7 +51,7 @@ signal.signal(signal.SIGINT, term_mp)
signal.signal(signal.SIGTERM, term_mp)
def build_dataloader(config, mode, device, logger):
def build_dataloader(config, mode, device, logger, seed=None):
config = copy.deepcopy(config)
support_dict = ['SimpleDataSet', 'LMDBDateSet']
......@@ -61,7 +61,7 @@ def build_dataloader(config, mode, device, logger):
assert mode in ['Train', 'Eval', 'Test'
], "Mode should be Train, Eval or Test."
dataset = eval(module_name)(config, mode, logger)
dataset = eval(module_name)(config, mode, logger, seed)
loader_config = config[mode]['loader']
batch_size = loader_config['batch_size_per_card']
drop_last = loader_config['drop_last']
......
......@@ -21,7 +21,7 @@ from .imaug import transform, create_operators
class LMDBDateSet(Dataset):
def __init__(self, config, mode, logger):
def __init__(self, config, mode, logger, seed=None):
super(LMDBDateSet, self).__init__()
global_config = config['Global']
......
......@@ -20,7 +20,7 @@ from .imaug import transform, create_operators
class SimpleDataSet(Dataset):
def __init__(self, config, mode, logger):
def __init__(self, config, mode, logger, seed=None):
super(SimpleDataSet, self).__init__()
self.logger = logger
......@@ -41,6 +41,7 @@ class SimpleDataSet(Dataset):
self.data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle']
self.seed = seed
logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
self.data_idx_order_list = list(range(len(self.data_lines)))
......@@ -55,6 +56,7 @@ class SimpleDataSet(Dataset):
for idx, file in enumerate(file_list):
with open(file, "rb") as f:
lines = f.readlines()
random.seed(self.seed)
lines = random.sample(lines,
round(len(lines) * ratio_list[idx]))
data_lines.extend(lines)
......@@ -62,6 +64,7 @@ class SimpleDataSet(Dataset):
def shuffle_data_random(self):
if self.do_shuffle:
random.seed(self.seed)
random.shuffle(self.data_lines)
return
......
......@@ -182,8 +182,8 @@ def train(config,
start_epoch = 1
for epoch in range(start_epoch, epoch_num + 1):
if epoch > 0:
train_dataloader = build_dataloader(config, 'Train', device, logger)
train_dataloader = build_dataloader(
config, 'Train', device, logger, seed=epoch)
train_batch_cost = 0.0
train_reader_cost = 0.0
batch_sum = 0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册