未验证 提交 622fdd51 编写于 作者: R ruri 提交者: GitHub

refine warning in image classification (#4018)

* refine warning
上级 7bc60080
...@@ -261,8 +261,6 @@ export FLAGS_fraction_of_gpu_memory_to_use=0.80 ...@@ -261,8 +261,6 @@ export FLAGS_fraction_of_gpu_memory_to_use=0.80
python -m paddle.distributed.launch train.py \ python -m paddle.distributed.launch train.py \
--model=ShuffleNetV2_x0_25 \ --model=ShuffleNetV2_x0_25 \
--batch_size=2048 \ --batch_size=2048 \
--class_dim=1000 \
--image_shape=3,224,224 \
--lr_strategy=cosine_decay_warmup \ --lr_strategy=cosine_decay_warmup \
--num_epochs=240 \ --num_epochs=240 \
--lr=0.5 \ --lr=0.5 \
......
...@@ -256,8 +256,6 @@ export FLAGS_fraction_of_gpu_memory_to_use=0.80 ...@@ -256,8 +256,6 @@ export FLAGS_fraction_of_gpu_memory_to_use=0.80
python -m paddle.distributed.launch train.py \ python -m paddle.distributed.launch train.py \
--model=ShuffleNetV2_x0_25 \ --model=ShuffleNetV2_x0_25 \
--batch_size=2048 \ --batch_size=2048 \
--class_dim=1000 \
--image_shape=3,224,224 \
--lr_strategy=cosine_decay_warmup \ --lr_strategy=cosine_decay_warmup \
--num_epochs=240 \ --num_epochs=240 \
--lr=0.5 \ --lr=0.5 \
......
...@@ -52,8 +52,6 @@ add_arg('use_se', bool, True, "Whether to use Squeeze- ...@@ -52,8 +52,6 @@ add_arg('use_se', bool, True, "Whether to use Squeeze-
def eval(args): def eval(args):
image_shape = args.image_shape
model_list = [m for m in dir(models) if "__" not in m] model_list = [m for m in dir(models) if "__" not in m]
assert args.model in model_list, "{} is not in lists: {}".format(args.model, assert args.model in model_list, "{} is not in lists: {}".format(args.model,
model_list) model_list)
...@@ -62,8 +60,11 @@ def eval(args): ...@@ -62,8 +60,11 @@ def eval(args):
), "{} doesn't exist, please load right pretrained model path for eval".format( ), "{} doesn't exist, please load right pretrained model path for eval".format(
args.pretrained_model) args.pretrained_model)
assert args.image_shape[
1] <= args.resize_short_size, "Please check the args:image_shape and args:resize_short_size, The croped size(image_shape[1]) must smaller than or equal to the resized length(resize_short_size) "
image = fluid.data( image = fluid.data(
name='image', shape=[None] + image_shape, dtype='float32') name='image', shape=[None] + args.image_shape, dtype='float32')
label = fluid.data(name='label', shape=[None, 1], dtype='int64') label = fluid.data(name='label', shape=[None, 1], dtype='int64')
# model definition # model definition
......
...@@ -54,14 +54,17 @@ add_arg('use_se', bool, True, "Whether to use Squeeze- ...@@ -54,14 +54,17 @@ add_arg('use_se', bool, True, "Whether to use Squeeze-
def infer(args): def infer(args):
image_shape = args.image_shape
model_list = [m for m in dir(models) if "__" not in m] model_list = [m for m in dir(models) if "__" not in m]
assert args.model in model_list, "{} is not in lists: {}".format(args.model, assert args.model in model_list, "{} is not in lists: {}".format(args.model,
model_list) model_list)
assert os.path.isdir(args.pretrained_model assert os.path.isdir(args.pretrained_model
), "please load right pretrained model path for infer" ), "please load right pretrained model path for infer"
assert args.image_shape[
1] <= args.resize_short_size, "Please check the args:image_shape and args:resize_short_size, The croped size(image_shape[1]) must smaller than or equal to the resized length(resize_short_size) "
image = fluid.data( image = fluid.data(
name='image', shape=[None] + image_shape, dtype='float32') name='image', shape=[None] + args.image_shape, dtype='float32')
if args.model.startswith('EfficientNet'): if args.model.startswith('EfficientNet'):
model = models.__dict__[args.model](is_test=True, model = models.__dict__[args.model](is_test=True,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册