提交 c95f6dd5 编写于 作者: S ShawnXuan

support cc data proc

上级 75025ae0
......@@ -82,3 +82,30 @@ def load_synthetic(args):
)
return label, image
def load_imagenet_for_training2(args):
total_device_num = args.num_nodes * args.gpu_num_per_node
train_batch_size = total_device_num * args.batch_size_per_device
seed = 0
color_space = 'RGB'
with flow.fixed_placement("cpu", "0:0"):
ofrecord = flow.data.ofrecord_loader(args.train_data_dir,
batch_size=train_batch_size,
data_part_num=args.train_data_part_num,
part_name_suffix_length=5)
image = flow.data.OFRecordImageDecoderRandomCrop(ofrecord, "encoded", seed=seed,
color_space=color_space)
label = flow.data.OFRecordRawDecoder(ofrecord, "class/label", shape=(), dtype=flow.int32)
rsz = flow.image.Resize(image, resize_x=float(args.image_size),
resize_y=float(args.image_size),
color_space=color_space)
print(rsz.shape)
print(label.shape)
rng = flow.image.CoinFlip(batch_size=train_batch_size, seed=seed)
normal = flow.image.CropMirrorNormalize(rsz, mirror_blob=rng, color_space=color_space,
mean=args.rgb_mean, std=args.rgb_std, output_dtype = flow.float)
print(normal.shape)
return label, normal
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册