提交 679fa655 编写于 作者: G Guanghua Yu 提交者: qingqing01

Add global shuffle for data reader in object_detection.

上级 9afe4f67
...@@ -3,7 +3,61 @@ from paddle.fluid.initializer import MSRA ...@@ -3,7 +3,61 @@ from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr from paddle.fluid.param_attr import ParamAttr
def conv_bn(input, class MobileNetSSD:
def __init__(self, img, num_classes, img_shape):
self.img = img
self.num_classes = num_classes
self.img_shape = img_shape
def ssd_net(self, scale=1.0):
    """Chain the backbone layers, the extra blocks and the multi-box head.

    Args:
        scale: Multiplier passed to every ``depthwise_separable`` and
            ``extra_block`` call; it also sizes the first convolution,
            which gets ``int(32 * scale)`` filters.

    Returns:
        The four values produced by ``fluid.layers.multi_box_head``,
        returned as the tuple ``(mbox_locs, mbox_confs, box, box_var)``.
    """
    # Stem convolution (tagged "300x300" in the original comments).
    feature = self.conv_bn(self.img, 3, int(32 * scale), 2, 1, 3)

    # (num_filters1, num_filters2, num_groups, stride) for each
    # depthwise_separable layer, listed in call order. The resolution
    # tags are carried over from the original inline comments.
    backbone_spec = (
        (32, 64, 32, 1),  # 150x150
        (64, 128, 64, 2),
        (128, 128, 128, 1),  # 75x75
        (128, 256, 128, 2),
        (256, 256, 256, 1),  # 38x38
        (256, 512, 256, 2),
    ) + ((512, 512, 512, 1), ) * 5  # 19x19: five identical layers
    for filters1, filters2, groups, stride in backbone_spec:
        feature = self.depthwise_separable(feature, filters1, filters2,
                                           groups, stride, scale)

    # First feature map handed to the detection head.
    module11 = feature
    feature = self.depthwise_separable(feature, 512, 1024, 512, 2, scale)
    # 10x10
    module13 = self.depthwise_separable(feature, 1024, 1024, 1024, 1,
                                        scale)

    # Each extra block consumes the previous detection input; all four
    # are called with num_groups=1 and stride=2.
    detection_inputs = [module11, module13]
    extra_spec = (
        (256, 512),
        (128, 256),  # 5x5
        (128, 256),  # 3x3
        (64, 128),  # 2x2
    )
    for filters1, filters2 in extra_spec:
        detection_inputs.append(
            self.extra_block(detection_inputs[-1], filters1, filters2, 1,
                             2, scale))

    head_outputs = fluid.layers.multi_box_head(
        inputs=detection_inputs,
        image=self.img,
        num_classes=self.num_classes,
        min_ratio=20,
        max_ratio=90,
        min_sizes=[60.0, 105.0, 150.0, 195.0, 240.0, 285.0],
        max_sizes=[[], 150.0, 195.0, 240.0, 285.0, 300.0],
        aspect_ratios=[[2.], [2., 3.], [2., 3.], [2., 3.], [2., 3.],
                       [2., 3.]],
        base_size=self.img_shape[2],
        offset=0.5,
        flip=True)
    mbox_locs, mbox_confs, box, box_var = head_outputs
    return mbox_locs, mbox_confs, box, box_var
def conv_bn(self,
input,
filter_size, filter_size,
num_filters, num_filters,
stride, stride,
...@@ -26,10 +80,9 @@ def conv_bn(input, ...@@ -26,10 +80,9 @@ def conv_bn(input,
bias_attr=False) bias_attr=False)
return fluid.layers.batch_norm(input=conv, act=act) return fluid.layers.batch_norm(input=conv, act=act)
def depthwise_separable(self, input, num_filters1, num_filters2, num_groups,
def depthwise_separable(input, num_filters1, num_filters2, num_groups, stride, stride, scale):
scale): depthwise_conv = self.conv_bn(
depthwise_conv = conv_bn(
input=input, input=input,
filter_size=3, filter_size=3,
num_filters=int(num_filters1 * scale), num_filters=int(num_filters1 * scale),
...@@ -38,7 +91,7 @@ def depthwise_separable(input, num_filters1, num_filters2, num_groups, stride, ...@@ -38,7 +91,7 @@ def depthwise_separable(input, num_filters1, num_filters2, num_groups, stride,
num_groups=int(num_groups * scale), num_groups=int(num_groups * scale),
use_cudnn=False) use_cudnn=False)
pointwise_conv = conv_bn( pointwise_conv = self.conv_bn(
input=depthwise_conv, input=depthwise_conv,
filter_size=1, filter_size=1,
num_filters=int(num_filters2 * scale), num_filters=int(num_filters2 * scale),
...@@ -46,10 +99,10 @@ def depthwise_separable(input, num_filters1, num_filters2, num_groups, stride, ...@@ -46,10 +99,10 @@ def depthwise_separable(input, num_filters1, num_filters2, num_groups, stride,
padding=0) padding=0)
return pointwise_conv return pointwise_conv
def extra_block(self, input, num_filters1, num_filters2, num_groups, stride,
def extra_block(input, num_filters1, num_filters2, num_groups, stride, scale): scale):
# 1x1 conv # 1x1 conv
pointwise_conv = conv_bn( pointwise_conv = self.conv_bn(
input=input, input=input,
filter_size=1, filter_size=1,
num_filters=int(num_filters1 * scale), num_filters=int(num_filters1 * scale),
...@@ -58,7 +111,7 @@ def extra_block(input, num_filters1, num_filters2, num_groups, stride, scale): ...@@ -58,7 +111,7 @@ def extra_block(input, num_filters1, num_filters2, num_groups, stride, scale):
padding=0) padding=0)
# 3x3 conv # 3x3 conv
normal_conv = conv_bn( normal_conv = self.conv_bn(
input=pointwise_conv, input=pointwise_conv,
filter_size=3, filter_size=3,
num_filters=int(num_filters2 * scale), num_filters=int(num_filters2 * scale),
...@@ -68,46 +121,6 @@ def extra_block(input, num_filters1, num_filters2, num_groups, stride, scale): ...@@ -68,46 +121,6 @@ def extra_block(input, num_filters1, num_filters2, num_groups, stride, scale):
return normal_conv return normal_conv
def mobile_net(num_classes, img, img_shape, scale=1.0): def build_mobilenet_ssd(img, num_classes, img_shape):
# 300x300 ssd_model = MobileNetSSD(img, num_classes, img_shape)
tmp = conv_bn(img, 3, int(32 * scale), 2, 1, 3) return ssd_model.ssd_net()
# 150x150
tmp = depthwise_separable(tmp, 32, 64, 32, 1, scale)
tmp = depthwise_separable(tmp, 64, 128, 64, 2, scale)
# 75x75
tmp = depthwise_separable(tmp, 128, 128, 128, 1, scale)
tmp = depthwise_separable(tmp, 128, 256, 128, 2, scale)
# 38x38
tmp = depthwise_separable(tmp, 256, 256, 256, 1, scale)
tmp = depthwise_separable(tmp, 256, 512, 256, 2, scale)
# 19x19
for i in range(5):
tmp = depthwise_separable(tmp, 512, 512, 512, 1, scale)
module11 = tmp
tmp = depthwise_separable(tmp, 512, 1024, 512, 2, scale)
# 10x10
module13 = depthwise_separable(tmp, 1024, 1024, 1024, 1, scale)
module14 = extra_block(module13, 256, 512, 1, 2, scale)
# 5x5
module15 = extra_block(module14, 128, 256, 1, 2, scale)
# 3x3
module16 = extra_block(module15, 128, 256, 1, 2, scale)
# 2x2
module17 = extra_block(module16, 64, 128, 1, 2, scale)
mbox_locs, mbox_confs, box, box_var = fluid.layers.multi_box_head(
inputs=[module11, module13, module14, module15, module16, module17],
image=img,
num_classes=num_classes,
min_ratio=20,
max_ratio=90,
min_sizes=[60.0, 105.0, 150.0, 195.0, 240.0, 285.0],
max_sizes=[[], 150.0, 195.0, 240.0, 285.0, 300.0],
aspect_ratios=[[2.], [2., 3.], [2., 3.], [2., 3.], [2., 3.], [2., 3.]],
base_size=img_shape[2],
offset=0.5,
flip=True)
return mbox_locs, mbox_confs, box, box_var
...@@ -293,6 +293,7 @@ def train(settings, ...@@ -293,6 +293,7 @@ def train(settings,
coco_api = COCO(file_path) coco_api = COCO(file_path)
image_ids = coco_api.getImgIds() image_ids = coco_api.getImgIds()
images = coco_api.loadImgs(image_ids) images = coco_api.loadImgs(image_ids)
np.random.shuffle(images)
n = int(math.ceil(len(images) // num_workers)) n = int(math.ceil(len(images) // num_workers))
image_lists = [images[i:i + n] for i in range(0, len(images), n)] image_lists = [images[i:i + n] for i in range(0, len(images), n)]
...@@ -307,11 +308,11 @@ def train(settings, ...@@ -307,11 +308,11 @@ def train(settings,
data_dir)) data_dir))
else: else:
images = [line.strip() for line in open(file_path)] images = [line.strip() for line in open(file_path)]
np.random.shuffle(images)
n = int(math.ceil(len(images) // num_workers)) n = int(math.ceil(len(images) // num_workers))
image_lists = [images[i:i + n] for i in range(0, len(images), n)] image_lists = [images[i:i + n] for i in range(0, len(images), n)]
for l in image_lists: for l in image_lists:
readers.append(pascalvoc(settings, l, 'train', batch_size, shuffle)) readers.append(pascalvoc(settings, l, 'train', batch_size, shuffle))
return paddle.reader.multiprocess_reader(readers, False) return paddle.reader.multiprocess_reader(readers, False)
...@@ -341,7 +342,7 @@ def infer(settings, image_path): ...@@ -341,7 +342,7 @@ def infer(settings, image_path):
"data path correctly." % image_path) "data path correctly." % image_path)
img = Image.open(image_path) img = Image.open(image_path)
if img.mode == 'L': if img.mode == 'L':
img = im.convert('RGB') img = img.convert('RGB')
im_width, im_height = img.size im_width, im_height = img.size
img = img.resize((settings.resize_w, settings.resize_h), img = img.resize((settings.resize_w, settings.resize_h),
Image.ANTIALIAS) Image.ANTIALIAS)
......
...@@ -10,7 +10,7 @@ import multiprocessing ...@@ -10,7 +10,7 @@ import multiprocessing
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import reader import reader
from mobilenet_ssd import mobile_net from mobilenet_ssd import build_mobilenet_ssd
from utility import add_arguments, print_arguments from utility import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
...@@ -92,7 +92,7 @@ def build_program(main_prog, startup_prog, train_params, is_train): ...@@ -92,7 +92,7 @@ def build_program(main_prog, startup_prog, train_params, is_train):
use_double_buffer=True) use_double_buffer=True)
with fluid.unique_name.guard(): with fluid.unique_name.guard():
image, gt_box, gt_label, difficult = fluid.layers.read_file(py_reader) image, gt_box, gt_label, difficult = fluid.layers.read_file(py_reader)
locs, confs, box, box_var = mobile_net(class_num, image, image_shape) locs, confs, box, box_var = build_mobilenet_ssd(image, class_num, image_shape)
if is_train: if is_train:
with fluid.unique_name.guard("train"): with fluid.unique_name.guard("train"):
loss = fluid.layers.ssd_loss(locs, confs, gt_box, gt_label, box, loss = fluid.layers.ssd_loss(locs, confs, gt_box, gt_label, box,
...@@ -228,6 +228,13 @@ def train(args, ...@@ -228,6 +228,13 @@ def train(args,
total_time = 0.0 total_time = 0.0
for epoc_id in range(epoc_num): for epoc_id in range(epoc_num):
train_reader = reader.train(data_args,
train_file_list,
batch_size_per_device,
shuffle=is_shuffle,
num_workers=num_workers,
enable_ce=enable_ce)
train_py_reader.decorate_paddle_reader(train_reader)
epoch_idx = epoc_id + 1 epoch_idx = epoc_id + 1
start_time = time.time() start_time = time.time()
prev_start_time = start_time prev_start_time = start_time
...@@ -255,9 +262,10 @@ def train(args, ...@@ -255,9 +262,10 @@ def train(args,
end_time = time.time() end_time = time.time()
total_time += end_time - start_time total_time += end_time - start_time
if epoc_id % 10 == 0 or epoc_id == epoc_num - 1:
best_map, mean_map = test(epoc_id, best_map) best_map, mean_map = test(epoc_id, best_map)
print("Best test map {0}".format(best_map)) print("Best test map {0}".format(best_map))
if epoc_id % 10 == 0 or epoc_id == epoc_num - 1: # save model
save_model(str(epoc_id), train_prog) save_model(str(epoc_id), train_prog)
if enable_ce: if enable_ce:
...@@ -275,7 +283,7 @@ def train(args, ...@@ -275,7 +283,7 @@ def train(args,
(devices_num, total_time / epoch_idx)) (devices_num, total_time / epoch_idx))
if __name__ == '__main__': def main():
args = parser.parse_args() args = parser.parse_args()
print_arguments(args) print_arguments(args)
...@@ -318,3 +326,7 @@ if __name__ == '__main__': ...@@ -318,3 +326,7 @@ if __name__ == '__main__':
train_parameters[dataset], train_parameters[dataset],
train_file_list=train_file_list, train_file_list=train_file_list,
val_file_list=val_file_list) val_file_list=val_file_list)
if __name__ == '__main__':
main()
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册