Commit daed129e authored by ShawnXuan

work with cc's new dataloader

Parent c0e98bd4
@@ -6,7 +6,7 @@ import argparse
 from datetime import datetime
 import logging
-from dali_util import add_dali_args
+#from dali_util import add_dali_args
 from optimizer_util import add_optimizer_args
 from ofrecord_util import add_ofrecord_args
@@ -40,7 +40,7 @@ def get_parser(parser=None):
     parser.add_argument("--batch_size_per_device", type=int, default=64)
     parser.add_argument("--val_batch_size_per_device", type=int, default=8)
     # for data process
     parser.add_argument("--num_examples", type=int, default=1281167, help="train pic number")
     parser.add_argument("--num_val_examples", type=int, default=50000, help="validation pic number")
     parser.add_argument('--rgb-mean', type=float_list, default=[123.68, 116.779, 103.939],
@@ -48,6 +48,8 @@ def get_parser(parser=None):
     parser.add_argument('--rgb-std', type=float_list, default=[58.393, 57.12, 57.375],
                         help='a tuple of size 3 for the std rgb')
     parser.add_argument("--input_layout", type=str, default='NHWC', help="NCHW or NHWC")
+    parser.add_argument('--image-shape', type=int_list, default=[3, 224, 224],
+                        help='the image shape feed into the network')
     ## snapshot
@@ -65,7 +67,7 @@ def get_parser(parser=None):
         default=1,
         help="print loss every n iteration",
     )
-    add_dali_args(parser)
+    #add_dali_args(parser)
     add_ofrecord_args(parser)
     add_optimizer_args(parser)
     return parser
...
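Note: the --rgb-mean, --rgb-std, and new --image-shape options use float_list / int_list type converters that are defined elsewhere in this repo and are not part of the diff. A minimal sketch of what such argparse converters typically look like (an assumption for illustration, not the repo's actual code):

# Hypothetical converters: parse a comma-separated CLI string into a list,
# e.g. --image-shape=3,224,224 -> [3, 224, 224].
def int_list(s):
    return [int(v) for v in s.split(',')]

def float_list(s):
    return [float(v) for v in s.split(',')]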
@@ -44,7 +44,8 @@ def TrainNet():
     if args.train_data_dir:
         assert os.path.exists(args.train_data_dir)
         print("Loading data from {}".format(args.train_data_dir))
-        (labels, images) = ofrecord_util.load_imagenet_for_training(args)
+        (labels, images) = ofrecord_util.load_imagenet_for_training2(args)
+        # note: images.shape = (N C H W) in cc's new dataloader(load_imagenet_for_training2)
     else:
         print("Loading synthetic data.")
         (labels, images) = ofrecord_util.load_synthetic(args)
@@ -63,7 +64,7 @@ def InferenceNet():
     if args.val_data_dir:
         assert os.path.exists(args.val_data_dir)
         print("Loading data from {}".format(args.val_data_dir))
-        (labels, images) = ofrecord_util.load_imagenet_for_validation(args)
+        (labels, images) = ofrecord_util.load_imagenet_for_validation2(args)
     else:
         print("Loading synthetic data.")
         (labels, images) = ofrecord_util.load_synthetic(args)
...
@@ -59,9 +59,9 @@ def load_imagenet_for_validation(args):
     codec=flow.data.ImageCodec(
         [
             #flow.data.ImagePreprocessor('bgr2rgb'),
-            flow.data.ImageTargetResizePreprocessor(resize_shorter=256),
-            flow.data.ImageCenterCropPreprocessor(args.image_size, args.image_size),
-            #flow.data.ImageResizePreprocessor(args.image_size, args.image_size),
+            # flow.data.ImageTargetResizePreprocessor(resize_shorter=256),
+            # flow.data.ImageCenterCropPreprocessor(args.image_size, args.image_size),
+            flow.data.ImageResizePreprocessor(args.image_size, args.image_size),
         ]
     )
     return load_imagenet(args, val_batch_size, args.val_data_dir, args.val_data_part_num, codec)
@@ -98,15 +98,37 @@ def load_imagenet_for_training2(args):
     image = flow.data.OFRecordImageDecoderRandomCrop(ofrecord, "encoded", seed=seed,
                                                      color_space=color_space)
     label = flow.data.OFRecordRawDecoder(ofrecord, "class/label", shape=(), dtype=flow.int32)
-    rsz = flow.image.Resize(image, resize_x=float(args.image_size),
-                            resize_y=float(args.image_size),
+    rsz = flow.image.Resize(image, resize_x=args.image_size,
+                            resize_y=args.image_size,
                             color_space=color_space)
-    rng = flow.image.CoinFlip(batch_size=train_batch_size, seed=seed)
+    rng = flow.random.CoinFlip(batch_size=train_batch_size, seed=seed)
     normal = flow.image.CropMirrorNormalize(rsz, mirror_blob=rng, color_space=color_space,
                                             mean=args.rgb_mean, std=args.rgb_std, output_dtype = flow.float)
     return label, normal

+def load_imagenet_for_validation2(args):
+    total_device_num = args.num_nodes * args.gpu_num_per_node
+    val_batch_size = total_device_num * args.val_batch_size_per_device
+    seed = 0
+    color_space = 'RGB'
+    with flow.fixed_placement("cpu", "0:0"):
+        ofrecord = flow.data.ofrecord_loader(args.val_data_dir,
+                                             batch_size=val_batch_size,
+                                             data_part_num=args.val_data_part_num,
+                                             part_name_suffix_length=5)
+        image = flow.data.OFRecordImageDecoderRandomCrop(ofrecord, "encoded", seed=seed,
+                                                         color_space=color_space)
+        label = flow.data.OFRecordRawDecoder(ofrecord, "class/label", shape=(), dtype=flow.int32)
+        rsz = flow.image.Resize(image, resize_x=args.image_size,
+                                resize_y=args.image_size,
+                                color_space=color_space)
+        rng = flow.random.CoinFlip(batch_size=val_batch_size, seed=seed)
+        normal = flow.image.CropMirrorNormalize(rsz, mirror_blob=rng, color_space=color_space,
+                                                mean=args.rgb_mean, std=args.rgb_std, output_dtype = flow.float)
+    return label, normal
+
 if __name__ == "__main__":
     import os
...
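The CropMirrorNormalize stage at the end of both new loader pipelines subtracts the per-channel rgb_mean and divides by rgb_std, and, per the note in TrainNet above, the result comes out as NCHW. A rough numpy equivalent of just that normalize-and-layout step (a sketch for reference, not the operator's actual implementation):

import numpy as np

RGB_MEAN = np.array([123.68, 116.779, 103.939], dtype=np.float32)  # --rgb-mean default
RGB_STD = np.array([58.393, 57.12, 57.375], dtype=np.float32)      # --rgb-std default

def normalize_to_nchw(images_nhwc):
    # uint8 (N, H, W, 3) -> float32 (N, 3, H, W), per-channel (x - mean) / std
    out = (images_nhwc.astype(np.float32) - RGB_MEAN) / RGB_STD
    return np.transpose(out, (0, 3, 1, 2))

batch = np.random.randint(0, 256, size=(8, 224, 224, 3), dtype=np.uint8)
print(normalize_to_nchw(batch).shape)  # (8, 3, 224, 224)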
@@ -127,7 +127,9 @@ def resnet_stem(input):
 def resnet50(images, trainable=True):
-    images = flow.transpose(images, name="transpose", perm=[0, 3, 1, 2])
+    # note: images.shape = (N C H W) in cc's new dataloader, transpose is not needed anymore
+    # images = flow.transpose(images, name="transpose", perm=[0, 3, 1, 2])
+    print(images.shape, "******************************")
     with flow.deprecated.variable_scope("Resnet"):
         stem = resnet_stem(images)
...
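The dropped flow.transpose converted the old loader's NHWC batches to NCHW before the stem; with the new loader the batches already arrive as NCHW, so applying perm=[0, 3, 1, 2] again would scramble the layout. Since --input_layout remains a parser option, a layout-aware guard is one way to keep both paths working (a hypothetical sketch, not part of this commit):

import oneflow as flow  # assumed import name, as used elsewhere in the repo

def to_nchw(images, input_layout):
    # Transpose only when the loader delivers NHWC; the new dataloader already yields NCHW.
    if input_layout == "NHWC":
        return flow.transpose(images, name="transpose", perm=[0, 3, 1, 2])
    return images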
@@ -4,14 +4,14 @@ rm -rf ./output/snapshots/*
 DATA_ROOT=/dataset/ImageNet/ofrecord
 #DATA_ROOT=/dataset/imagenet-mxnet
 #python3 cnn_benchmark/of_cnn_train_val.py \
-#gdb --args \
 #nvprof -f -o resnet.nvvp \
+#gdb --args \
 python3 cnn_e2e/of_cnn_train_val.py \
     --train_data_dir=$DATA_ROOT/train \
     --train_data_part_num=256 \
     --val_data_dir=$DATA_ROOT/validation \
     --val_data_part_num=256 \
-    --num_nodes=2 \
+    --num_nodes=1 \
     --node_ips='11.11.1.12,11.11.1.14' \
     --gpu_num_per_node=4 \
     --optimizer="momentum-cosine-decay" \
...
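For reference, the loaders above size their global batch as total_device_num * batch_size_per_device (see load_imagenet_for_validation2). A quick arithmetic check with this script's settings, assuming the parser defaults of 64/8 images per device are not overridden:

# Effective global batch sizes for this run (per-device defaults assumed).
num_nodes = 1
gpu_num_per_node = 4
total_device_num = num_nodes * gpu_num_per_node    # 4 devices
train_batch_size = total_device_num * 64           # 256 images per training step
val_batch_size = total_device_num * 8              # 32 images per validation step
print(total_device_num, train_batch_size, val_batch_size)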