“4fe0621cc49b37eeded0b6853c6a9facd89a6685”上不存在“bsp/stm32/stm32f091-st-nucleo/rtconfig.h”
未验证 提交 6a70cbd3 编写于 作者: W whs 提交者: GitHub

Fix dataloader when distributed training (#1242)

上级 185b3c96
...@@ -122,6 +122,8 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list): ...@@ -122,6 +122,8 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
def main(): def main():
rank_id = paddle.distributed.get_rank()
place = paddle.CUDAPlace(rank_id)
global global_config global global_config
all_config = load_slim_config(args.config_path) all_config = load_slim_config(args.config_path)
...@@ -155,6 +157,7 @@ def main(): ...@@ -155,6 +157,7 @@ def main():
train_loader = DataLoader( train_loader = DataLoader(
train_dataset, train_dataset,
places=[place],
batch_size=global_config['batch_size'], batch_size=global_config['batch_size'],
shuffle=True, shuffle=True,
drop_last=True, drop_last=True,
......
...@@ -146,7 +146,8 @@ def reader_wrapper(reader): ...@@ -146,7 +146,8 @@ def reader_wrapper(reader):
if __name__ == '__main__': if __name__ == '__main__':
rank_id = paddle.distributed.get_rank()
place = paddle.CUDAPlace(rank_id)
args = parse_args() args = parse_args()
paddle.enable_static() paddle.enable_static()
# step1: load dataset config and create dataloader # step1: load dataset config and create dataloader
...@@ -160,6 +161,7 @@ if __name__ == '__main__': ...@@ -160,6 +161,7 @@ if __name__ == '__main__':
drop_last=True) drop_last=True)
train_loader = paddle.io.DataLoader( train_loader = paddle.io.DataLoader(
train_dataset, train_dataset,
places=[place],
batch_sampler=batch_sampler, batch_sampler=batch_sampler,
num_workers=2, num_workers=2,
return_list=True, return_list=True,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册