未验证 提交 7ac5b3cd 编写于 作者: B baiyf 提交者: GitHub

Merge pull request #986 from baiyfbupt/develop

Refine Pyramidbox infer and Pyramidbox configure
model/
pretrained/
data/
label/
pretrained/
*.swp
*.log
infer_results/
......@@ -11,7 +11,6 @@ import paddle.fluid as fluid
import reader
from pyramidbox import PyramidBox
from utility import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
......@@ -20,73 +19,272 @@ add_arg('use_pyramidbox', bool, False, "Whether use PyramidBox model.")
add_arg('confs_threshold', float, 0.25, "Confidence threshold to draw bbox.")
add_arg('image_path', str, '', "The data root path.")
add_arg('model_dir', str, '', "The model path.")
add_arg('resize_h', int, 0, "The resized image height.")
add_arg('resize_w', int, 0, "The resized image height.")
# yapf: enable
def draw_bounding_box_on_image(image_path, nms_out, confs_threshold):
image = Image.open(image_path)
draw = ImageDraw.Draw(image)
im_width, im_height = image.size
for dt in nms_out:
category_id, score, xmin, ymin, xmax, ymax = dt.tolist()
xmin, ymin, xmax, ymax, score = dt
if score < confs_threshold:
continue
bbox = dt[2:]
xmin, ymin, xmax, ymax = bbox
(left, right, top, bottom) = (xmin * im_width, xmax * im_width,
ymin * im_height, ymax * im_height)
(left, right, top, bottom) = (xmin, xmax, ymin, ymax)
draw.line(
[(left, top), (left, bottom), (right, bottom), (right, top),
(left, top)],
width=4,
fill='red')
image_name = image_path.split('/')[-1]
image_class = image_path.split('/')[-2]
print("image with bbox drawed saved as {}".format(image_name))
image.save(image_name)
image.save('./infer_results/' + image_class.encode('utf-8') + '/' +
image_name.encode('utf-8'))
def infer(args, data_args):
num_classes = 2
infer_reader = reader.infer(data_args, args.image_path)
data = infer_reader()
def write_to_txt(image_path, f, nms_out):
image_name = image_path.split('/')[-1]
image_class = image_path.split('/')[-2]
f.write('{:s}\n'.format(
image_class.encode('utf-8') + '/' + image_name.encode('utf-8')))
f.write('{:d}\n'.format(nms_out.shape[0]))
for dt in nms_out:
xmin, ymin, xmax, ymax, score = dt
f.write('{:.1f} {:.1f} {:.1f} {:.1f} {:.3f}\n'.format(xmin, ymin, (
xmax - xmin + 1), (ymax - ymin + 1), score))
print("image infer result saved {}".format(image_name[:-4]))
if args.resize_h and args.resize_w:
image_shape = [3, args.resize_h, args.resize_w]
else:
image_shape = data.shape[1:]
fetches = []
def get_round(x, loc):
str_x = str(x)
if '.' in str_x:
len_after = len(str_x.split('.')[1])
str_before = str_x.split('.')[0]
str_after = str_x.split('.')[1]
if len_after >= 3:
str_final = str_before + '.' + str_after[0:loc]
return float(str_final)
else:
return x
def bbox_vote(det):
order = det[:, 4].ravel().argsort()[::-1]
det = det[order, :]
if det.shape[0] == 0:
dets = np.array([[10, 10, 20, 20, 0.002]])
det = np.empty(shape=[0, 5])
while det.shape[0] > 0:
# IOU
area = (det[:, 2] - det[:, 0] + 1) * (det[:, 3] - det[:, 1] + 1)
xx1 = np.maximum(det[0, 0], det[:, 0])
yy1 = np.maximum(det[0, 1], det[:, 1])
xx2 = np.minimum(det[0, 2], det[:, 2])
yy2 = np.minimum(det[0, 3], det[:, 3])
w = np.maximum(0.0, xx2 - xx1 + 1)
h = np.maximum(0.0, yy2 - yy1 + 1)
inter = w * h
o = inter / (area[0] + area[:] - inter)
# get needed merge det and delete these det
merge_index = np.where(o >= 0.3)[0]
det_accu = det[merge_index, :]
det = np.delete(det, merge_index, 0)
if merge_index.shape[0] <= 1:
if det.shape[0] == 0:
try:
dets = np.row_stack((dets, det_accu))
except:
dets = det_accu
continue
det_accu[:, 0:4] = det_accu[:, 0:4] * np.tile(det_accu[:, -1:], (1, 4))
max_score = np.max(det_accu[:, 4])
det_accu_sum = np.zeros((1, 5))
det_accu_sum[:, 0:4] = np.sum(det_accu[:, 0:4],
axis=0) / np.sum(det_accu[:, -1:])
det_accu_sum[:, 4] = max_score
try:
dets = np.row_stack((dets, det_accu_sum))
except:
dets = det_accu_sum
dets = dets[0:750, :]
return dets
def image_preprocess(image):
img = np.array(image)
# HWC to CHW
if len(img.shape) == 3:
img = np.swapaxes(img, 1, 2)
img = np.swapaxes(img, 1, 0)
# RBG to BGR
img = img[[2, 1, 0], :, :]
img = img.astype('float32')
img -= np.array(
[104., 117., 123.])[:, np.newaxis, np.newaxis].astype('float32')
img = img * 0.007843
img = [img]
img = np.array(img)
return img
network = PyramidBox(
image_shape,
num_classes,
sub_network=args.use_pyramidbox,
is_infer=True)
infer_program, nmsed_out = network.infer()
fetches = [nmsed_out]
def detect_face(image, shrink):
image_shape = [3, image.size[1], image.size[0]]
num_classes = 2
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
model_dir = args.model_dir
if not os.path.exists(model_dir):
raise ValueError("The model path [%s] does not exist." % (model_dir))
if shrink != 1:
image = image.resize((int(image_shape[2] * shrink),
int(image_shape[1] * shrink)), Image.ANTIALIAS)
image_shape = [
image_shape[0], int(image_shape[1] * shrink),
int(image_shape[2] * shrink)
]
print "image_shape:", image_shape
img = image_preprocess(image)
scope = fluid.core.Scope()
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.scope_guard(scope):
with fluid.unique_name.guard():
with fluid.program_guard(main_program, startup_program):
fetches = []
network = PyramidBox(
image_shape,
num_classes,
sub_network=args.use_pyramidbox,
is_infer=True)
infer_program, nmsed_out = network.infer(main_program)
fetches = [nmsed_out]
fluid.io.load_persistables(
exe, args.model_dir, main_program=main_program)
detection, = exe.run(infer_program,
feed={'image': img},
fetch_list=fetches,
return_numpy=False)
detection = np.array(detection)
# layout: xmin, ymin, xmax. ymax, score
det_conf = detection[:, 1]
det_xmin = image_shape[2] * detection[:, 2] / shrink
det_ymin = image_shape[1] * detection[:, 3] / shrink
det_xmax = image_shape[2] * detection[:, 4] / shrink
det_ymax = image_shape[1] * detection[:, 5] / shrink
det = np.column_stack((det_xmin, det_ymin, det_xmax, det_ymax, det_conf))
keep_index = np.where(det[:, 4] >= 0)[0]
det = det[keep_index, :]
return det
def flip_test(image, shrink):
img = image.transpose(Image.FLIP_LEFT_RIGHT)
det_f = detect_face(img, shrink)
det_t = np.zeros(det_f.shape)
# image.size: [width, height]
det_t[:, 0] = image.size[0] - det_f[:, 2]
det_t[:, 1] = det_f[:, 1]
det_t[:, 2] = image.size[0] - det_f[:, 0]
det_t[:, 3] = det_f[:, 3]
det_t[:, 4] = det_f[:, 4]
return det_t
def multi_scale_test(image, max_shrink):
# shrink detecting and shrink only detect big face
st = 0.5 if max_shrink >= 0.75 else 0.5 * max_shrink
det_s = detect_face(image, st)
index = np.where(
np.maximum(det_s[:, 2] - det_s[:, 0] + 1, det_s[:, 3] - det_s[:, 1] + 1)
> 30)[0]
det_s = det_s[index, :]
# enlarge one times
bt = min(2, max_shrink) if max_shrink > 1 else (st + max_shrink) / 2
det_b = detect_face(image, bt)
# enlarge small image x times for small face
if max_shrink > 2:
bt *= 2
while bt < max_shrink:
det_b = np.row_stack((det_b, detect_face(image, bt)))
bt *= 2
det_b = np.row_stack((det_b, detect_face(image, max_shrink)))
# enlarge only detect small face
if bt > 1:
index = np.where(
np.minimum(det_b[:, 2] - det_b[:, 0] + 1,
det_b[:, 3] - det_b[:, 1] + 1) < 100)[0]
det_b = det_b[index, :]
else:
index = np.where(
np.maximum(det_b[:, 2] - det_b[:, 0] + 1,
det_b[:, 3] - det_b[:, 1] + 1) > 30)[0]
det_b = det_b[index, :]
return det_s, det_b
def get_im_shrink(image_shape):
max_shrink_v1 = (0x7fffffff / 577.0 /
(image_shape[1] * image_shape[2]))**0.5
max_shrink_v2 = (
(678 * 1024 * 2.0 * 2.0) / (image_shape[1] * image_shape[2]))**0.5
max_shrink = get_round(min(max_shrink_v1, max_shrink_v2), 2) - 0.3
if max_shrink >= 1.5 and max_shrink < 2:
max_shrink = max_shrink - 0.1
elif max_shrink >= 2 and max_shrink < 3:
max_shrink = max_shrink - 0.2
elif max_shrink >= 3 and max_shrink < 4:
max_shrink = max_shrink - 0.3
elif max_shrink >= 4 and max_shrink < 5:
max_shrink = max_shrink - 0.4
elif max_shrink >= 5:
max_shrink = max_shrink - 0.5
print 'max_shrink = ', max_shrink
shrink = max_shrink if max_shrink < 1 else 1
print "shrink = ", shrink
return shrink, max_shrink
def infer(args, batch_size, data_args):
if not os.path.exists(args.model_dir):
raise ValueError("The model path [%s] does not exist." %
(args.model_dir))
infer_reader = paddle.batch(
reader.test(data_args, file_list), batch_size=batch_size)
for batch_id, img in enumerate(infer_reader()):
image = img[0][0]
image_path = img[0][1]
# image.size: [width, height]
image_shape = [3, image.size[1], image.size[0]]
shrink, max_shrink = get_im_shrink(image_shape)
def if_exist(var):
return os.path.exists(os.path.join(model_dir, var.name))
det0 = detect_face(image, shrink)
det1 = flip_test(image, shrink)
[det2, det3] = multi_scale_test(image, max_shrink)
det = np.row_stack((det0, det1, det2, det3))
dets = bbox_vote(det)
fluid.io.load_vars(exe, model_dir, predicate=if_exist)
image_name = image_path.split('/')[-1]
image_class = image_path.split('/')[-2]
if not os.path.exists('./infer_results/' + image_class.encode('utf-8')):
os.makedirs('./infer_results/' + image_class.encode('utf-8'))
feed = {'image': fluid.create_lod_tensor(data, [], place)}
predict, = exe.run(infer_program,
feed=feed,
fetch_list=fetches,
return_numpy=False)
predict = np.array(predict)
draw_bounding_box_on_image(args.image_path, predict, args.confs_threshold)
f = open('./infer_results/' + image_class.encode('utf-8') + '/' +
image_name.encode('utf-8')[:-4] + '.txt', 'w')
write_to_txt(image_path, f, dets)
# draw_bounding_box_on_image(image_path, dets, args.confs_threshold)
print "Done"
if __name__ == '__main__':
......@@ -98,10 +296,8 @@ if __name__ == '__main__':
data_args = reader.Settings(
data_dir=data_dir,
resize_h=args.resize_h,
resize_w=args.resize_w,
mean_value=[104., 117., 123],
apply_distort=False,
apply_expand=False,
ap_version='11point')
infer(args, data_args=data_args)
infer(args, batch_size=1, data_args=data_args)
......@@ -39,7 +39,11 @@ def conv_block(input, groups, filters, ksizes, strides=None, with_pool=True):
act='relu')
if with_pool:
pool = fluid.layers.pool2d(
input=conv, pool_size=2, pool_type='max', pool_stride=2)
input=conv,
pool_size=2,
pool_type='max',
pool_stride=2,
ceil_mode=True)
return conv, pool
else:
return conv
......@@ -148,6 +152,8 @@ class PyramidBox(object):
b_attr = ParamAttr(learning_rate=2., regularizer=L2Decay(0.))
conv2 = fluid.layers.conv2d(
up_to, ch, 1, act='relu', bias_attr=b_attr)
if self.is_infer:
upsampling = fluid.layers.crop(upsampling, shape=conv2)
# eltwise mul
conv_fuse = upsampling * conv2
return conv_fuse
......@@ -393,8 +399,11 @@ class PyramidBox(object):
total_loss = face_loss + head_loss
return face_loss, head_loss, total_loss
def infer(self):
test_program = fluid.default_main_program().clone(for_test=True)
def infer(self, main_program=None):
if main_program is None:
test_program = fluid.default_main_program().clone(for_test=True)
else:
test_program = main_program.clone(for_test=True)
with fluid.program_guard(test_program):
face_nmsed_out = fluid.layers.detection_output(
self.face_mbox_loc,
......
......@@ -238,34 +238,38 @@ def pyramidbox(settings, file_list, mode, shuffle):
im_width, im_height = im.size
# layout: label | xmin | ymin | xmax | ymax
bbox_labels = []
for index_box in range(len(dict_input_txt[index_image])):
if index_box >= 2:
bbox_sample = []
temp_info_box = dict_input_txt[index_image][
index_box].split(' ')
xmin = float(temp_info_box[0])
ymin = float(temp_info_box[1])
w = float(temp_info_box[2])
h = float(temp_info_box[3])
xmax = xmin + w
ymax = ymin + h
bbox_sample.append(1)
bbox_sample.append(float(xmin) / im_width)
bbox_sample.append(float(ymin) / im_height)
bbox_sample.append(float(xmax) / im_width)
bbox_sample.append(float(ymax) / im_height)
bbox_labels.append(bbox_sample)
im, sample_labels = preprocess(im, bbox_labels, mode, settings)
sample_labels = np.array(sample_labels)
if len(sample_labels) == 0: continue
im = im.astype('float32')
boxes = sample_labels[:, 1:5]
lbls = [1] * len(boxes)
difficults = [1] * len(boxes)
yield im, boxes, expand_bboxes(boxes), lbls, difficults
if mode == 'train':
bbox_labels = []
for index_box in range(len(dict_input_txt[index_image])):
if index_box >= 2:
bbox_sample = []
temp_info_box = dict_input_txt[index_image][
index_box].split(' ')
xmin = float(temp_info_box[0])
ymin = float(temp_info_box[1])
w = float(temp_info_box[2])
h = float(temp_info_box[3])
xmax = xmin + w
ymax = ymin + h
bbox_sample.append(1)
bbox_sample.append(float(xmin) / im_width)
bbox_sample.append(float(ymin) / im_height)
bbox_sample.append(float(xmax) / im_width)
bbox_sample.append(float(ymax) / im_height)
bbox_labels.append(bbox_sample)
im, sample_labels = preprocess(im, bbox_labels, mode, settings)
sample_labels = np.array(sample_labels)
if len(sample_labels) == 0: continue
im = im.astype('float32')
boxes = sample_labels[:, 1:5]
lbls = [1] * len(boxes)
difficults = [1] * len(boxes)
yield im, boxes, expand_bboxes(boxes), lbls, difficults
if mode == 'test':
yield im, image_path
return reader
......@@ -274,6 +278,10 @@ def train(settings, file_list, shuffle=True):
return pyramidbox(settings, file_list, 'train', shuffle)
def test(settings, file_list):
return pyramidbox(settings, file_list, 'test', False)
def infer(settings, image_path):
def batch_reader():
img = Image.open(image_path)
......
......@@ -16,11 +16,11 @@ add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('parallel', bool, True, "parallel")
add_arg('learning_rate', float, 0.0001, "Learning rate.")
add_arg('batch_size', int, 16, "Minibatch size.")
add_arg('learning_rate', float, 0.001, "Learning rate.")
add_arg('batch_size', int, 12, "Minibatch size.")
add_arg('num_passes', int, 120, "Epoch number.")
add_arg('use_gpu', bool, True, "Whether use GPU.")
add_arg('use_pyramidbox', bool, False, "Whether use PyramidBox model.")
add_arg('use_pyramidbox', bool, True, "Whether use PyramidBox model.")
add_arg('dataset', str, 'WIDERFACE', "coco2014, coco2017, and pascalvoc.")
add_arg('model_save_dir', str, 'model', "The path to save model.")
add_arg('pretrained_model', str, './pretrained/', "The init model path.")
......@@ -50,10 +50,10 @@ def train(args, data_args, learning_rate, batch_size, pretrained_model,
fetches = [loss]
epocs = 12880 / batch_size
boundaries = [epocs * 100, epocs * 125, epocs * 150]
boundaries = [epocs * 40, epocs * 60, epocs * 80, epocs * 100]
values = [
learning_rate, learning_rate * 0.1, learning_rate * 0.01,
learning_rate * 0.001
learning_rate, learning_rate * 0.5, learning_rate * 0.25,
learning_rate * 0.1, learning_rate * 0.01
]
if optimizer_method == "momentum":
......@@ -70,12 +70,19 @@ def train(args, data_args, learning_rate, batch_size, pretrained_model,
)
optimizer.minimize(loss)
# fluid.memory_optimize(fluid.default_main_program())
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
start_pass = 0
if pretrained_model:
if pretrained_model.isdigit():
start_pass = int(pretrained_model) + 1
pretrained_model = os.path.join(args.model_save_dir, pretrained_model)
print("Resume from %s " %(pretrained_model))
if not os.path.exists(pretrained_model):
raise ValueError("The pre-trained model path [%s] does not exist." %
(pretrained_model))
......@@ -98,14 +105,14 @@ def train(args, data_args, learning_rate, batch_size, pretrained_model,
print 'save models to %s' % (model_path)
fluid.io.save_persistables(exe, model_path)
for pass_id in range(num_passes):
for pass_id in range(start_pass, num_passes):
start_time = time.time()
prev_start_time = start_time
end_time = 0
for batch_id, data in enumerate(train_reader()):
prev_start_time = start_time
start_time = time.time()
if len(data) < devices_num: continue
if len(data) < 2 * devices_num: continue
if args.parallel:
fetch_vars = train_exe.run(fetch_list=[v.name for v in fetches],
feed=feeder.feed(data))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册