未验证 提交 638d8f42 编写于 作者: S ShawnXuan 提交者: GitHub

Merge pull request #143 from Oneflow-Inc/dev_test_oneflow_0.2_perf

Add optimization parameters
......@@ -44,4 +44,5 @@ python3 $BENCH_ROOT/of_cnn_train_val.py \
--fuse_bn_add_relu=True \
--nccl_fusion_threshold_mb=16 \
--nccl_fusion_max_ops=24 \
--gpu_image_decoder=True \
--model="resnet50"
......@@ -111,7 +111,8 @@ def get_parser(parser=None):
default=False,
help='Whether to use use fuse batch normalization add relu. Currently supported in origin/master of OneFlow only.'
)
parser.add_argument("--gpu_image_decoder", type=str2bool,
default=False, help='Whether to use use ImageDecoderRandomCropResize.')
# inference
parser.add_argument("--image_path", type=str, default='test_img/tiger.jpg', help="image path")
......
......@@ -35,6 +35,10 @@ def get_train_config(args):
train_config.prune_parallel_cast_ops(True)
train_config.enable_inplace(True)
if args.num_nodes > 1:
train_config.cudnn_conv_heuristic_search_algo(True)
else:
train_config.cudnn_conv_heuristic_search_algo(False)
train_config.enable_fuse_model_update_ops(True)
return train_config
......
......@@ -90,14 +90,19 @@ def load_imagenet_for_training(args):
part_name_suffix_length=5,
random_shuffle=True,
shuffle_after_epoch=True)
image = flow.data.OFRecordImageDecoderRandomCrop(ofrecord, "encoded", # seed=seed,
color_space=color_space)
label = flow.data.OFRecordRawDecoder(
ofrecord, "class/label", shape=(), dtype=flow.int32)
if args.gpu_image_decoder:
encoded = flow.data.OFRecordBytesDecoder(ofrecord, "encoded")
image = flow.data.ImageDecoderRandomCropResize(encoded, target_width=224, target_height=224, num_workers=3)
else:
image = flow.data.OFRecordImageDecoderRandomCrop(ofrecord, "encoded", # seed=seed,
color_space=color_space)
rsz = flow.image.Resize(image, target_size=[args.image_size, args.image_size])
image = rsz[0]
rsz = flow.image.Resize(image, target_size=[args.image_size, args.image_size])
rng = flow.random.CoinFlip(batch_size=train_batch_size) # , seed=seed)
normal = flow.image.CropMirrorNormalize(rsz[0], mirror_blob=rng,
normal = flow.image.CropMirrorNormalize(image, mirror_blob=rng,
color_space=color_space, output_layout=output_layout,
mean=args.rgb_mean, std=args.rgb_std, output_dtype=flow.float)
return label, normal
......
rm -rf core.*
rm -rf ./output/snapshots/*
if [ -n "$1" ]; then
NUM_EPOCH=$1
else
NUM_EPOCH=50
fi
echo NUM_EPOCH=$NUM_EPOCH
# training with imagenet
if [ -n "$2" ]; then
DATA_ROOT=$2
else
DATA_ROOT=/data/imagenet/ofrecord
fi
echo DATA_ROOT=$DATA_ROOT
LOG_FOLDER=../logs
mkdir -p $LOG_FOLDER
LOGFILE=$LOG_FOLDER/resnet_training.log
export PYTHONUNBUFFERED=1
echo PYTHONUNBUFFERED=$PYTHONUNBUFFERED
export NCCL_LAUNCH_MODE=PARALLEL
echo NCCL_LAUNCH_MODE=$NCCL_LAUNCH_MODE
python3 of_cnn_train_val.py \
--train_data_dir=$DATA_ROOT/train \
--train_data_part_num=256 \
--val_data_dir=$DATA_ROOT/validation \
--val_data_part_num=256 \
--num_nodes=1 \
--gpu_num_per_node=8 \
--optimizer="sgd" \
--momentum=0.875 \
--label_smoothing=0.1 \
--learning_rate=1.536 \
--loss_print_every_n_iter=100 \
--batch_size_per_device=192 \
--val_batch_size_per_device=50 \
--use_fp16 \
--channel_last=True \
--pad_output \
--fuse_bn_relu=True \
--fuse_bn_add_relu=True \
--nccl_fusion_threshold_mb=16 \
--nccl_fusion_max_ops=24 \
--gpu_image_decoder=True \
--num_epoch=$NUM_EPOCH \
--model="resnet50" 2>&1 | tee ${LOGFILE}
echo "Writting log to ${LOGFILE}"
rm -rf core.*
rm -rf ./output/snapshots/*
if [ -n "$1" ]; then
NUM_EPOCH=$1
else
NUM_EPOCH=50
fi
echo NUM_EPOCH=$NUM_EPOCH
# training with imagenet
if [ -n "$2" ]; then
DATA_ROOT=$2
else
DATA_ROOT=/data/imagenet/ofrecord
fi
echo DATA_ROOT=$DATA_ROOT
LOG_FOLDER=../logs
mkdir -p $LOG_FOLDER
LOGFILE=$LOG_FOLDER/resnet_training.log
export PYTHONUNBUFFERED=1
echo PYTHONUNBUFFERED=$PYTHONUNBUFFERED
export NCCL_LAUNCH_MODE=PARALLEL
echo NCCL_LAUNCH_MODE=$NCCL_LAUNCH_MODE
python3 of_cnn_train_val.py \
--train_data_dir=$DATA_ROOT/train \
--train_data_part_num=256 \
--val_data_dir=$DATA_ROOT/validation \
--val_data_part_num=256 \
--num_nodes=1 \
--gpu_num_per_node=8 \
--optimizer="sgd" \
--momentum=0.875 \
--label_smoothing=0.1 \
--learning_rate=0.768 \
--loss_print_every_n_iter=100 \
--batch_size_per_device=96 \
--val_batch_size_per_device=50 \
--channel_last=False \
--fuse_bn_relu=True \
--fuse_bn_add_relu=True \
--nccl_fusion_threshold_mb=16 \
--nccl_fusion_max_ops=24 \
--gpu_image_decoder=True \
--num_epoch=$NUM_EPOCH \
--model="resnet50" 2>&1 | tee ${LOGFILE}
echo "Writting log to ${LOGFILE}"
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册