未验证 提交 7935b528 编写于 作者: R ruri 提交者: GitHub

fix padding bug and add resize_short_size args (#2428)

* fix padding bug and add resize_short_size args

* fix typo

* pass args:settings in eval and infer
上级 524d3705
...@@ -7,7 +7,7 @@ import time ...@@ -7,7 +7,7 @@ import time
import sys import sys
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import reader as reader import reader_cv2 as reader
import argparse import argparse
import functools import functools
import models import models
...@@ -25,7 +25,7 @@ add_arg('image_shape', str, "3,224,224", "Input image size") ...@@ -25,7 +25,7 @@ add_arg('image_shape', str, "3,224,224", "Input image size")
add_arg('with_mem_opt', bool, True, "Whether to use memory optimization or not.") add_arg('with_mem_opt', bool, True, "Whether to use memory optimization or not.")
add_arg('pretrained_model', str, None, "Whether to use pretrained model.") add_arg('pretrained_model', str, None, "Whether to use pretrained model.")
add_arg('model', str, "SE_ResNeXt50_32x4d", "Set the network to use.") add_arg('model', str, "SE_ResNeXt50_32x4d", "Set the network to use.")
add_arg('resize_short_size', int, 256, "Set resize short size")
# yapf: enable # yapf: enable
def eval(args): def eval(args):
...@@ -84,7 +84,7 @@ def eval(args): ...@@ -84,7 +84,7 @@ def eval(args):
fluid.io.load_vars(exe, pretrained_model, predicate=if_exist) fluid.io.load_vars(exe, pretrained_model, predicate=if_exist)
val_reader = paddle.batch(reader.val(), batch_size=args.batch_size) val_reader = paddle.batch(reader.val(settings=args), batch_size=args.batch_size)
feeder = fluid.DataFeeder(place=place, feed_list=[image, label]) feeder = fluid.DataFeeder(place=place, feed_list=[image, label])
test_info = [[], [], []] test_info = [[], [], []]
......
...@@ -7,7 +7,7 @@ import time ...@@ -7,7 +7,7 @@ import time
import sys import sys
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import reader import reader_cv2 as reader
import argparse import argparse
import functools import functools
import models import models
...@@ -25,6 +25,7 @@ add_arg('with_mem_opt', bool, True, "Whether to use memory o ...@@ -25,6 +25,7 @@ add_arg('with_mem_opt', bool, True, "Whether to use memory o
add_arg('pretrained_model', str, None, "Whether to use pretrained model.") add_arg('pretrained_model', str, None, "Whether to use pretrained model.")
add_arg('model', str, "SE_ResNeXt50_32x4d", "Set the network to use.") add_arg('model', str, "SE_ResNeXt50_32x4d", "Set the network to use.")
add_arg('save_inference', bool, False, "Whether to save inference model or not") add_arg('save_inference', bool, False, "Whether to save inference model or not")
add_arg('resize_short_size', int, 256, "Set resize short size")
# yapf: enable # yapf: enable
def infer(args): def infer(args):
...@@ -78,7 +79,7 @@ def infer(args): ...@@ -78,7 +79,7 @@ def infer(args):
print("model: ",model_name," is already saved") print("model: ",model_name," is already saved")
exit(0) exit(0)
test_batch_size = 1 test_batch_size = 1
test_reader = paddle.batch(reader.test(), batch_size=test_batch_size) test_reader = paddle.batch(reader.test(settings=args), batch_size=test_batch_size)
feeder = fluid.DataFeeder(place=place, feed_list=[image]) feeder = fluid.DataFeeder(place=place, feed_list=[image])
TOPK = 1 TOPK = 1
......
...@@ -92,7 +92,7 @@ class ResNet(): ...@@ -92,7 +92,7 @@ class ResNet():
num_filters=num_filters, num_filters=num_filters,
filter_size=filter_size, filter_size=filter_size,
stride=stride, stride=stride,
padding=(filter_size - 1) / 2, padding=(filter_size - 1) // 2,
groups=groups, groups=groups,
act=None, act=None,
param_attr=ParamAttr(name=name + "_weights"), param_attr=ParamAttr(name=name + "_weights"),
......
...@@ -104,7 +104,7 @@ class ResNet(): ...@@ -104,7 +104,7 @@ class ResNet():
num_filters=num_filters, num_filters=num_filters,
filter_size=filter_size, filter_size=filter_size,
stride=stride, stride=stride,
padding=(filter_size - 1) / 2, padding=(filter_size - 1) // 2,
groups=groups, groups=groups,
act=None, act=None,
param_attr=ParamAttr(name=name + "_weights"), param_attr=ParamAttr(name=name + "_weights"),
...@@ -140,7 +140,7 @@ class ResNet(): ...@@ -140,7 +140,7 @@ class ResNet():
num_filters=num_filters, num_filters=num_filters,
filter_size=filter_size, filter_size=filter_size,
stride=1, stride=1,
padding=(filter_size - 1) / 2, padding=(filter_size - 1) // 2,
groups=groups, groups=groups,
act=None, act=None,
param_attr=ParamAttr(name=name + "_weights"), param_attr=ParamAttr(name=name + "_weights"),
......
...@@ -101,7 +101,7 @@ class ResNeXt(): ...@@ -101,7 +101,7 @@ class ResNeXt():
num_filters=num_filters, num_filters=num_filters,
filter_size=filter_size, filter_size=filter_size,
stride=stride, stride=stride,
padding=(filter_size - 1) / 2, padding=(filter_size - 1) // 2,
groups=groups, groups=groups,
act=None, act=None,
param_attr=ParamAttr(name=name + "_weights"), param_attr=ParamAttr(name=name + "_weights"),
...@@ -137,7 +137,7 @@ class ResNeXt(): ...@@ -137,7 +137,7 @@ class ResNeXt():
num_filters=num_filters, num_filters=num_filters,
filter_size=filter_size, filter_size=filter_size,
stride=1, stride=1,
padding=(filter_size - 1) / 2, padding=(filter_size - 1) // 2,
groups=groups, groups=groups,
act=None, act=None,
param_attr=ParamAttr(name=name + "_weights"), param_attr=ParamAttr(name=name + "_weights"),
......
...@@ -173,7 +173,7 @@ class SE_ResNeXt(): ...@@ -173,7 +173,7 @@ class SE_ResNeXt():
num_filters=num_filters, num_filters=num_filters,
filter_size=filter_size, filter_size=filter_size,
stride=stride, stride=stride,
padding=(filter_size - 1) / 2, padding=(filter_size - 1) // 2,
groups=groups, groups=groups,
act=None, act=None,
bias_attr=False, bias_attr=False,
...@@ -205,7 +205,7 @@ class SE_ResNeXt(): ...@@ -205,7 +205,7 @@ class SE_ResNeXt():
num_filters=num_filters, num_filters=num_filters,
filter_size=filter_size, filter_size=filter_size,
stride=1, stride=1,
padding=(filter_size - 1) / 2, padding=(filter_size - 1) // 2,
groups=groups, groups=groups,
act=None, act=None,
param_attr=ParamAttr(name=name + "_weights"), param_attr=ParamAttr(name=name + "_weights"),
......
...@@ -162,7 +162,7 @@ def process_image( ...@@ -162,7 +162,7 @@ def process_image(
else: else:
if crop_size > 0: if crop_size > 0:
target_size = settings.resize_short_size target_size = settings.resize_short_size
img = resize_short(img, 256) img = resize_short(img, target_size)
img = crop_image(img, target_size=crop_size, center=True) img = crop_image(img, target_size=crop_size, center=True)
...@@ -222,14 +222,14 @@ def _reader_creator(settings, ...@@ -222,14 +222,14 @@ def _reader_creator(settings,
img_path = os.path.join(data_dir, img_path) img_path = os.path.join(data_dir, img_path)
yield [img_path] yield [img_path]
crop_size = int(settings.image_shape.split(",")[2])
image_mapper = functools.partial( image_mapper = functools.partial(
process_image, process_image,
settings=settings, settings=settings,
mode=mode, mode=mode,
color_jitter=color_jitter, color_jitter=color_jitter,
rotate=rotate, rotate=rotate,
crop_size=224) crop_size=crop_size)
reader = paddle.reader.xmap_readers( reader = paddle.reader.xmap_readers(
image_mapper, reader, THREAD, BUF_SIZE, order=False) image_mapper, reader, THREAD, BUF_SIZE, order=False)
return reader return reader
......
...@@ -38,7 +38,7 @@ add_arg('pretrained_model', str, None, "Whether to use pretrai ...@@ -38,7 +38,7 @@ add_arg('pretrained_model', str, None, "Whether to use pretrai
add_arg('checkpoint', str, None, "Whether to resume checkpoint.") add_arg('checkpoint', str, None, "Whether to resume checkpoint.")
add_arg('lr', float, 0.1, "set learning rate.") add_arg('lr', float, 0.1, "set learning rate.")
add_arg('lr_strategy', str, "piecewise_decay", "Set the learning rate decay strategy.") add_arg('lr_strategy', str, "piecewise_decay", "Set the learning rate decay strategy.")
add_arg('model', str, "ResNet50", "Set the network to use.") add_arg('model', str, "SE_ResNeXt50_32x4d", "Set the network to use.")
add_arg('enable_ce', bool, False, "If set True, enable continuous evaluation job.") add_arg('enable_ce', bool, False, "If set True, enable continuous evaluation job.")
add_arg('data_dir', str, "./data/ILSVRC2012/", "The ImageNet dataset root dir.") add_arg('data_dir', str, "./data/ILSVRC2012/", "The ImageNet dataset root dir.")
add_arg('fp16', bool, False, "Enable half precision training with fp16." ) add_arg('fp16', bool, False, "Enable half precision training with fp16." )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册