未验证 提交 5799742f 编写于 作者: W whs 提交者: GitHub

Fix dataloader (#1245)

上级 6a70cbd3
......@@ -14,8 +14,6 @@
import os
import sys
sys.path[0] = os.path.join(
os.path.dirname("__file__"), os.path.pardir, os.path.pardir)
import argparse
import functools
from functools import partial
......@@ -60,7 +58,7 @@ def reader_wrapper(reader, input_name):
return gen
def eval_reader(data_dir, batch_size, crop_size, resize_size):
def eval_reader(data_dir, batch_size, crop_size, resize_size, place=None):
val_reader = ImageNetDataset(
mode='val',
data_dir=data_dir,
......@@ -68,6 +66,7 @@ def eval_reader(data_dir, batch_size, crop_size, resize_size):
resize_size=resize_size)
val_loader = DataLoader(
val_reader,
places=[place] if place is not None else None,
batch_size=global_config['batch_size'],
shuffle=False,
drop_last=False,
......@@ -171,13 +170,14 @@ def main():
save_dir=args.save_dir,
config=all_config,
train_dataloader=train_dataloader,
eval_callback=eval_function,
eval_callback=eval_function if rank_id == 0 else None,
eval_dataloader=reader_wrapper(
eval_reader(
data_dir,
global_config['batch_size'],
crop_size=img_size,
resize_size=resize_size),
resize_size=resize_size,
place=place),
global_config['input_name']))
ac.compress()
......
......@@ -66,10 +66,6 @@ def parse_args():
def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
nranks = paddle.distributed.ParallelEnv().local_rank
if nranks > 1 and paddle.distributed.get_rank() != 0:
return
batch_sampler = paddle.io.BatchSampler(
eval_dataset, batch_size=1, shuffle=False, drop_last=False)
loader = paddle.io.DataLoader(
......@@ -168,6 +164,11 @@ if __name__ == '__main__':
worker_init_fn=worker_init_fn)
train_dataloader = reader_wrapper(train_loader)
nranks = paddle.distributed.get_world_size()
rank_id = paddle.distributed.get_rank()
if nranks > 1 and rank_id != 0:
eval_function = None
# step2: create and instance of AutoCompression
ac = AutoCompression(
model_dir=args.model_dir,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册