未验证 提交 ba609980 编写于 作者: W Wu Yi 提交者: GitHub

Merge pull request #1691 from typhoonzero/add_reader_dist_global_shuffle

add reader dist global shuffle
...@@ -80,9 +80,9 @@ def get_device_num(): ...@@ -80,9 +80,9 @@ def get_device_num():
device_num = subprocess.check_output(['nvidia-smi', '-L']).decode().count('\n') device_num = subprocess.check_output(['nvidia-smi', '-L']).decode().count('\n')
return device_num return device_num
def prepare_reader(is_train, pyreader, args): def prepare_reader(is_train, pyreader, args, pass_id=0):
if is_train: if is_train:
reader = train(data_dir=args.data_dir) reader = train(data_dir=args.data_dir, pass_id_as_seed=pass_id)
else: else:
reader = val(data_dir=args.data_dir) reader = val(data_dir=args.data_dir)
if is_train: if is_train:
...@@ -262,6 +262,8 @@ def train_parallel(args): ...@@ -262,6 +262,8 @@ def train_parallel(args):
num_samples = 0 num_samples = 0
start_time = time.time() start_time = time.time()
batch_id = 1 batch_id = 1
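# use pass_id + 1 as the per-pass seed for the global shuffle in distributed training (the + 1 keeps the seed non-zero, since a seed of 0 is treated as "no seed" by the reader)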
prepare_reader(True, train_pyreader, args, pass_id + 1)
train_pyreader.start() train_pyreader.start()
while True: while True:
try: try:
......
...@@ -130,11 +130,14 @@ def _reader_creator(file_list, ...@@ -130,11 +130,14 @@ def _reader_creator(file_list,
shuffle=False, shuffle=False,
color_jitter=False, color_jitter=False,
rotate=False, rotate=False,
data_dir=DATA_DIR): data_dir=DATA_DIR,
pass_id_as_seed=0):
def reader(): def reader():
with open(file_list) as flist: with open(file_list) as flist:
full_lines = [line.strip() for line in flist] full_lines = [line.strip() for line in flist]
if shuffle: if shuffle:
if pass_id_as_seed:
np.random.seed(pass_id_as_seed)
np.random.shuffle(full_lines) np.random.shuffle(full_lines)
if mode == 'train' and os.getenv('PADDLE_TRAINING_ROLE'): if mode == 'train' and os.getenv('PADDLE_TRAINING_ROLE'):
# distributed mode if the env var `PADDLE_TRAINING_ROLE` exists # distributed mode if the env var `PADDLE_TRAINING_ROLE` exists
...@@ -166,7 +169,7 @@ def _reader_creator(file_list, ...@@ -166,7 +169,7 @@ def _reader_creator(file_list,
return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE) return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE)
def train(data_dir=DATA_DIR): def train(data_dir=DATA_DIR, pass_id_as_seed=0):
file_list = os.path.join(data_dir, 'train_list.txt') file_list = os.path.join(data_dir, 'train_list.txt')
return _reader_creator( return _reader_creator(
file_list, file_list,
...@@ -174,7 +177,8 @@ def train(data_dir=DATA_DIR): ...@@ -174,7 +177,8 @@ def train(data_dir=DATA_DIR):
shuffle=True, shuffle=True,
color_jitter=False, color_jitter=False,
rotate=False, rotate=False,
data_dir=data_dir) data_dir=data_dir,
pass_id_as_seed=pass_id_as_seed)
def val(data_dir=DATA_DIR): def val(data_dir=DATA_DIR):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册