提交 7015966e 编写于 作者: S ShawnXuan

new dataloader

上级 528682fb
......@@ -6,6 +6,7 @@ import oneflow as flow
def add_ofrecord_args(parser):
parser.add_argument("--image_size", type=int, default=224, required=False, help="image size")
parser.add_argument("--resize_shorter", type=int, default=256, required=False, help="resize shorter for validation")
parser.add_argument("--train_data_dir", type=str, default=None, help="train dataset directory")
parser.add_argument("--train_data_part_num", type=int, default=256, help="train data part num")
parser.add_argument("--val_data_dir", type=str, default=None, help="val dataset directory")
......@@ -88,22 +89,21 @@ def load_imagenet_for_training2(args):
total_device_num = args.num_nodes * args.gpu_num_per_node
train_batch_size = total_device_num * args.batch_size_per_device
seed = 0
color_space = 'RGB'
with flow.fixed_placement("cpu", "0:0"):
with flow.fixed_placement("cpu", "0:0-{}".format(args.gpu_num_per_node - 1)):
ofrecord = flow.data.ofrecord_reader(args.train_data_dir,
batch_size=train_batch_size,
data_part_num=args.train_data_part_num,
part_name_suffix_length=5,
random_shuffle = True,
shuffle_after_epoch=True)
image = flow.data.OFRecordImageDecoderRandomCrop(ofrecord, "encoded", seed=seed,
image = flow.data.OFRecordImageDecoderRandomCrop(ofrecord, "encoded", #seed=seed,
color_space=color_space)
label = flow.data.OFRecordRawDecoder(ofrecord, "class/label", shape=(), dtype=flow.int32)
rsz = flow.image.Resize(image, resize_x=args.image_size,
resize_y=args.image_size,
rsz = flow.image.Resize(image, resize_x=args.image_size, resize_y=args.image_size,
color_space=color_space)
rng = flow.random.CoinFlip(batch_size=train_batch_size, seed=seed)
rng = flow.random.CoinFlip(batch_size=train_batch_size)#, seed=seed)
normal = flow.image.CropMirrorNormalize(rsz, mirror_blob=rng, color_space=color_space,
mean=args.rgb_mean, std=args.rgb_std, output_dtype = flow.float)
return label, normal
......@@ -112,23 +112,19 @@ def load_imagenet_for_validation2(args):
total_device_num = args.num_nodes * args.gpu_num_per_node
val_batch_size = total_device_num * args.val_batch_size_per_device
seed = 0
color_space = 'RGB'
with flow.fixed_placement("cpu", "0:0"):
with flow.fixed_placement("cpu", "0:0-{}".format(args.gpu_num_per_node - 1)):
ofrecord = flow.data.ofrecord_reader(args.val_data_dir,
batch_size=val_batch_size,
data_part_num=args.val_data_part_num,
part_name_suffix_length=5,
shuffle_after_epoch=False)
image = flow.data.OFRecordImageDecoderRandomCrop(ofrecord, "encoded", seed=seed,
color_space=color_space)
image = flow.data.OFRecordImageDecoder(ofrecord, "encoded", color_space=color_space)
label = flow.data.OFRecordRawDecoder(ofrecord, "class/label", shape=(), dtype=flow.int32)
rsz = flow.image.Resize(image, resize_x=args.image_size,
resize_y=args.image_size,
color_space=color_space)
rsz = flow.image.Resize(image, resize_shorter=args.resize_shorter, color_space=color_space)
rng = flow.random.CoinFlip(batch_size=val_batch_size, seed=seed)
normal = flow.image.CropMirrorNormalize(rsz, mirror_blob=rng, color_space=color_space,
normal = flow.image.CropMirrorNormalize(rsz, color_space=color_space,
crop_h = args.image_size, crop_w = args.image_size, crop_pos_y = 0.5, crop_pos_x = 0.5,
mean=args.rgb_mean, std=args.rgb_std, output_dtype = flow.float)
return label, normal
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册