Unverified commit 6a70cbd3, authored by whs, committed by GitHub

Fix dataloader when distributed training (#1242)

Parent 185b3c96
@@ -122,6 +122,8 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
def main():
rank_id = paddle.distributed.get_rank()
place = paddle.CUDAPlace(rank_id)
global global_config
all_config = load_slim_config(args.config_path)
@@ -155,6 +157,7 @@ def main():
train_loader = DataLoader(
train_dataset,
places=[place],
batch_size=global_config['batch_size'],
shuffle=True,
drop_last=True,
@@ -146,7 +146,8 @@ def reader_wrapper(reader):
if __name__ == '__main__':
rank_id = paddle.distributed.get_rank()
place = paddle.CUDAPlace(rank_id)
args = parse_args()
paddle.enable_static()
# step1: load dataset config and create dataloader
@@ -160,6 +161,7 @@ if __name__ == '__main__':
drop_last=True)
train_loader = paddle.io.DataLoader(
train_dataset,
places=[place],
batch_sampler=batch_sampler,
num_workers=2,
return_list=True,