提交 f1e327e0 编写于 作者: T typhoonzero

add reader dist global shuffle

上级 2de73266
......@@ -80,9 +80,9 @@ def get_device_num():
device_num = subprocess.check_output(['nvidia-smi', '-L']).decode().count('\n')
return device_num
def prepare_reader(is_train, pyreader, args):
def prepare_reader(is_train, pyreader, args, pass_id=0):
if is_train:
reader = train(data_dir=args.data_dir)
reader = train(data_dir=args.data_dir, pass_id_as_seed=pass_id)
else:
reader = val(data_dir=args.data_dir)
if is_train:
......@@ -262,6 +262,8 @@ def train_parallel(args):
num_samples = 0
start_time = time.time()
batch_id = 1
# use pass_id+1 as per pass global shuffle for distributed training
prepare_reader(True, train_pyreader, args, pass_id + 1)
train_pyreader.start()
while True:
try:
......
......@@ -130,11 +130,14 @@ def _reader_creator(file_list,
shuffle=False,
color_jitter=False,
rotate=False,
data_dir=DATA_DIR):
data_dir=DATA_DIR,
pass_id_as_seed=0):
def reader():
with open(file_list) as flist:
full_lines = [line.strip() for line in flist]
if shuffle:
if pass_id_as_seed:
np.random.seed(pass_id_as_seed)
np.random.shuffle(full_lines)
if mode == 'train' and os.getenv('PADDLE_TRAINING_ROLE'):
# distributed mode if the env var `PADDLE_TRAINING_ROLE` exits
......@@ -166,7 +169,7 @@ def _reader_creator(file_list,
return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE)
def train(data_dir=DATA_DIR):
def train(data_dir=DATA_DIR, pass_id_as_seed=0):
file_list = os.path.join(data_dir, 'train_list.txt')
return _reader_creator(
file_list,
......@@ -174,7 +177,8 @@ def train(data_dir=DATA_DIR):
shuffle=True,
color_jitter=False,
rotate=False,
data_dir=data_dir)
data_dir=data_dir,
pass_id_as_seed=pass_id_as_seed)
def val(data_dir=DATA_DIR):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册