未验证 提交 5799742f 编写于 作者: W whs 提交者: GitHub

Fix dataloader (#1245)

上级 6a70cbd3
...@@ -14,8 +14,6 @@ ...@@ -14,8 +14,6 @@
import os import os
import sys import sys
sys.path[0] = os.path.join(
os.path.dirname("__file__"), os.path.pardir, os.path.pardir)
import argparse import argparse
import functools import functools
from functools import partial from functools import partial
...@@ -60,7 +58,7 @@ def reader_wrapper(reader, input_name): ...@@ -60,7 +58,7 @@ def reader_wrapper(reader, input_name):
return gen return gen
def eval_reader(data_dir, batch_size, crop_size, resize_size): def eval_reader(data_dir, batch_size, crop_size, resize_size, place=None):
val_reader = ImageNetDataset( val_reader = ImageNetDataset(
mode='val', mode='val',
data_dir=data_dir, data_dir=data_dir,
...@@ -68,6 +66,7 @@ def eval_reader(data_dir, batch_size, crop_size, resize_size): ...@@ -68,6 +66,7 @@ def eval_reader(data_dir, batch_size, crop_size, resize_size):
resize_size=resize_size) resize_size=resize_size)
val_loader = DataLoader( val_loader = DataLoader(
val_reader, val_reader,
places=[place] if place is not None else None,
batch_size=global_config['batch_size'], batch_size=global_config['batch_size'],
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
...@@ -171,13 +170,14 @@ def main(): ...@@ -171,13 +170,14 @@ def main():
save_dir=args.save_dir, save_dir=args.save_dir,
config=all_config, config=all_config,
train_dataloader=train_dataloader, train_dataloader=train_dataloader,
eval_callback=eval_function, eval_callback=eval_function if rank_id == 0 else None,
eval_dataloader=reader_wrapper( eval_dataloader=reader_wrapper(
eval_reader( eval_reader(
data_dir, data_dir,
global_config['batch_size'], global_config['batch_size'],
crop_size=img_size, crop_size=img_size,
resize_size=resize_size), resize_size=resize_size,
place=place),
global_config['input_name'])) global_config['input_name']))
ac.compress() ac.compress()
......
...@@ -66,10 +66,6 @@ def parse_args(): ...@@ -66,10 +66,6 @@ def parse_args():
def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list): def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
nranks = paddle.distributed.ParallelEnv().local_rank
if nranks > 1 and paddle.distributed.get_rank() != 0:
return
batch_sampler = paddle.io.BatchSampler( batch_sampler = paddle.io.BatchSampler(
eval_dataset, batch_size=1, shuffle=False, drop_last=False) eval_dataset, batch_size=1, shuffle=False, drop_last=False)
loader = paddle.io.DataLoader( loader = paddle.io.DataLoader(
...@@ -168,6 +164,11 @@ if __name__ == '__main__': ...@@ -168,6 +164,11 @@ if __name__ == '__main__':
worker_init_fn=worker_init_fn) worker_init_fn=worker_init_fn)
train_dataloader = reader_wrapper(train_loader) train_dataloader = reader_wrapper(train_loader)
nranks = paddle.distributed.get_world_size()
rank_id = paddle.distributed.get_rank()
if nranks > 1 and rank_id != 0:
eval_function = None
# step2: create and instance of AutoCompression # step2: create and instance of AutoCompression
ac = AutoCompression( ac = AutoCompression(
model_dir=args.model_dir, model_dir=args.model_dir,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册