Unverified commit 4d87afd6 authored by liu zhengxi, committed by GitHub

Fix hung (#5121)

* fix hung

* add shuffle batch

* update

* reader_seed to shuffle_seed

* seed for shuffle batch
Parent 047b8b69
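
The hang this commit fixes comes from multi-card data-parallel training: when each card shuffles its samples in a different order, token-budget batching can pack them into a different number of batches per card, and the card that expects more batches blocks forever in the next collective op. Below is a minimal sketch (not part of the patch; num_batches and its token budget are invented for illustration) of how the batch count depends on the shuffle order, and why a shared shuffle_seed makes the counts agree.

# Minimal illustrative sketch (not from the patch): with a token-budget
# batching rule, the number of batches depends on the sample order, so
# cards shuffling with different seeds can disagree on the batch count
# and deadlock in the next all-reduce.
import random

def num_batches(lengths, max_tokens, seed):
    order = lengths[:]
    random.Random(seed).shuffle(order)  # per-card shuffle order
    batches, tokens = 0, 0
    for n in order:
        if tokens + n > max_tokens:     # close the current batch
            batches += 1
            tokens = 0
        tokens += n
    return batches + (tokens > 0)

rng = random.Random(0)
lengths = [rng.randint(5, 60) for _ in range(2000)]
# Unshared seeds: counts can differ across cards, hanging training.
print(num_batches(lengths, 256, seed=1), num_batches(lengths, 256, seed=2))
# Shared shuffle_seed: every card sees the same order and the same count.
print(num_batches(lengths, 256, seed=128), num_batches(lengths, 256, seed=128))

With sort_type: "global" the order is fully determined by sorting, which is why the config comment below notes that shuffling only takes effect when sort_type is pool or none.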
@@ -27,6 +27,12 @@ pool_size: 200000
 sort_type: "global"
 batch_size: 4096
 infer_batch_size: 16
+shuffle_batch: True
+# Data shuffle only works when sort_type is pool or none
+shuffle: True
+# shuffle_seed must be set when shuffle is True and training with multiple cards.
+# Otherwise, the number of batches cannot be guaranteed to match across cards.
+shuffle_seed: 128
 # Hyperparams for training:
 # The number of epochs for training
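One detail in the reader change below: the seed is compared against the string "None" as well as Python's None, because YAML's null literal is null/~, so a config value written as None loads as a plain string. A quick check, assuming PyYAML is the config loader:

# Quick check (assuming PyYAML loads these configs): YAML's null literal
# is `null`/`~`, so a value written as `None` parses as the string "None".
import yaml

print(repr(yaml.safe_load("shuffle_seed: None")["shuffle_seed"]))  # 'None'
print(repr(yaml.safe_load("shuffle_seed: 128")["shuffle_seed"]))   # 128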
@@ -43,6 +43,12 @@ def create_data_loader(args):
         mode=m, transform_func=transform_func) for m in ["train", "dev"]
     ]
+    if args.shuffle or args.shuffle_batch:
+        if args.shuffle_seed == "None" or args.shuffle_seed is None:
+            shuffle_seed = 0
+        else:
+            shuffle_seed = args.shuffle_seed
+
     def _max_token_fn(current_idx, current_batch_size, tokens_sofar,
                       data_source):
         return max(tokens_sofar,
@@ -69,7 +75,8 @@ def create_data_loader(args):
                 key=trg_key, buffer_size=buffer_size).sort(
                     key=src_key, buffer_size=buffer_size)
         else:
-            sampler = sampler.shuffle()
+            if args.shuffle:
+                sampler = sampler.shuffle(seed=shuffle_seed)
             if args.sort_type == SortType.POOL:
                 buffer_size = args.pool_size
                 sampler = sampler.sort(key=src_key, buffer_size=buffer_size)
@@ -83,6 +90,9 @@ def create_data_loader(args):
         if m == "train":
             batch_sampler = batch_sampler.shard()
+            if args.shuffle_batch:
+                batch_sampler = batch_sampler.shuffle(seed=shuffle_seed)
+
         data_loader = DataLoader(
             dataset=dataset,
             batch_sampler=batch_sampler,
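Putting the hunks together, the sampler pipeline becomes: seeded sample shuffle, length sort within a pool, token-budget batching, shard across cards, then a seeded batch-level shuffle. Here is a hedged sketch of that pipeline over plain lists (a hypothetical helper, not PaddleNLP's actual SamplerHelper API):

# Hedged sketch of the pipeline the patch wires up (hypothetical helper,
# not PaddleNLP's SamplerHelper API). Every stage that involves randomness
# takes the shared shuffle_seed, so all cards derive the same batch list
# before sharding and their batch counts stay in sync.
import random

def build_batches(lengths, pool_size, max_tokens, num_cards, rank, seed):
    idx = list(range(len(lengths)))
    random.Random(seed).shuffle(idx)                  # seeded sample shuffle
    pools = [sorted(idx[i:i + pool_size], key=lengths.__getitem__)
             for i in range(0, len(idx), pool_size)]  # sort_type: "pool"
    batches, cur, tokens = [], [], 0
    for i in (j for pool in pools for j in pool):
        if cur and tokens + lengths[i] > max_tokens:  # token-budget batching
            batches.append(cur)
            cur, tokens = [], 0
        cur.append(i)
        tokens += lengths[i]
    if cur:
        batches.append(cur)
    mine = batches[rank::num_cards]                   # shard for this card
    random.Random(seed).shuffle(mine)                 # seeded batch shuffle
    return mine

The key property is that everything up to the shard step is deterministic given shuffle_seed, which is what the config comment about guaranteeing the number of batches asks for.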
@@ -27,6 +27,12 @@ pool_size: 200000
 sort_type: "global"
 batch_size: 4096
 infer_batch_size: 8
+shuffle_batch: True
+# Data shuffle only works when sort_type is pool or none
+shuffle: True
+# shuffle_seed must be set when shuffle is True and training with multiple cards.
+# Otherwise, the number of batches cannot be guaranteed to match across cards.
+shuffle_seed: 128
 # Hyperparams for training:
 # The number of epochs for training
@@ -27,6 +27,12 @@ pool_size: 200000
 sort_type: "global"
 batch_size: 4096
 infer_batch_size: 8
+shuffle_batch: True
+# Data shuffle only works when sort_type is pool or none
+shuffle: True
+# shuffle_seed must be set when shuffle is True and training with multiple cards.
+# Otherwise, the number of batches cannot be guaranteed to match across cards.
+shuffle_seed: 128
 # Hyperparams for training:
 # The number of epochs for training
@@ -43,6 +43,12 @@ def create_data_loader(args):
         mode=m, transform_func=transform_func) for m in ["train", "dev"]
     ]
+    if args.shuffle or args.shuffle_batch:
+        if args.shuffle_seed == "None" or args.shuffle_seed is None:
+            shuffle_seed = 0
+        else:
+            shuffle_seed = args.shuffle_seed
+
     def _max_token_fn(current_idx, current_batch_size, tokens_sofar,
                       data_source):
         return max(tokens_sofar,
@@ -69,7 +75,8 @@ def create_data_loader(args):
                 key=trg_key, buffer_size=buffer_size).sort(
                     key=src_key, buffer_size=buffer_size)
         else:
-            sampler = sampler.shuffle()
+            if args.shuffle:
+                sampler = sampler.shuffle(seed=shuffle_seed)
             if args.sort_type == SortType.POOL:
                 buffer_size = args.pool_size
                 sampler = sampler.sort(key=src_key, buffer_size=buffer_size)
@@ -83,6 +90,9 @@ def create_data_loader(args):
         if m == "train":
             batch_sampler = batch_sampler.shard()
+            if args.shuffle_batch:
+                batch_sampler = batch_sampler.shuffle(seed=shuffle_seed)
+
         data_loader = DataLoader(
             dataset=dataset,
             batch_sampler=batch_sampler,