diff --git a/fluid/DeepASR/data_utils/async_data_reader.py b/fluid/DeepASR/data_utils/async_data_reader.py index 0c8d010755cc4a947507aeb6a65343f0d160f2be..edface051129b248bad85978118daec6f8660adc 100644 --- a/fluid/DeepASR/data_utils/async_data_reader.py +++ b/fluid/DeepASR/data_utils/async_data_reader.py @@ -207,7 +207,7 @@ class AsyncDataReader(object): feature_file_list, label_file_list="", drop_frame_len=512, - split_sentence_threshold=512, + split_sentence_threshold=1024, proc_num=10, sample_buffer_size=1024, sample_info_buffer_size=1024, diff --git a/fluid/DeepASR/examples/aishell/profile.sh b/fluid/DeepASR/examples/aishell/profile.sh index a7397c308749341b11c4b3d0d2166ec077559834..e7df868b9ea26db3d91be0c01d0b7ecb63c374de 100644 --- a/fluid/DeepASR/examples/aishell/profile.sh +++ b/fluid/DeepASR/examples/aishell/profile.sh @@ -1,7 +1,7 @@ -export CUDA_VISIBLE_DEVICES=0,1,2,3 +export CUDA_VISIBLE_DEVICES=0 python -u ../../tools/profile.py --feature_lst data/train_feature.lst \ --label_lst data/train_label.lst \ - --mean_var data/aishell/global_mean_var \ - --parallel \ + --mean_var data/global_mean_var \ --frame_dim 80 \ --class_num 3040 \ + --batch_size 16 diff --git a/fluid/DeepASR/examples/aishell/train.sh b/fluid/DeepASR/examples/aishell/train.sh index c536c93263bab262a7abc45dc087a1980f6e45d1..06fe488d4572782d946e8daa7c22ded8ef0212c6 100644 --- a/fluid/DeepASR/examples/aishell/train.sh +++ b/fluid/DeepASR/examples/aishell/train.sh @@ -1,9 +1,9 @@ -export CUDA_VISIBLE_DEVICES=0,1,2,3 +export CUDA_VISIBLE_DEVICES=4,5,6,7 python -u ../../train.py --train_feature_lst data/train_feature.lst \ --train_label_lst data/train_label.lst \ --val_feature_lst data/val_feature.lst \ --val_label_lst data/val_label.lst \ - --mean_var data/aishell/global_mean_var \ + --mean_var data/global_mean_var \ --checkpoints checkpoints \ --frame_dim 80 \ --class_num 3040 \ @@ -11,4 +11,3 @@ python -u ../../train.py --train_feature_lst data/train_feature.lst \ --batch_size 64 \ --learning_rate 6.4e-5 \ --parallel -~ diff --git a/fluid/DeepASR/infer_by_ckpt.py b/fluid/DeepASR/infer_by_ckpt.py index 07e2d6fc56a639db51930dac24a99bef6d1dc3c4..2461852cd42f0b8d0f0a1d7c9ef8828f02074cf1 100644 --- a/fluid/DeepASR/infer_by_ckpt.py +++ b/fluid/DeepASR/infer_by_ckpt.py @@ -187,7 +187,12 @@ def infer_from_ckpt(args): infer_program = fluid.default_main_program().clone() - optimizer = fluid.optimizer.Adam(learning_rate=args.learning_rate) + optimizer = fluid.optimizer.Adam( + learning_rate=fluid.layers.exponential_decay( + learning_rate=args.learning_rate, + decay_steps=1879, + decay_rate=1 / 1.2, + staircase=True)) optimizer.minimize(avg_cost) place = fluid.CPUPlace() if args.device == 'CPU' else fluid.CUDAPlace(0) diff --git a/fluid/DeepASR/tools/profile.py b/fluid/DeepASR/tools/profile.py index 801252c848d28a42fa0e18bd4f41d323db3bc217..d25e18f7db0111acf76e66478f8230aab1d5f760 100644 --- a/fluid/DeepASR/tools/profile.py +++ b/fluid/DeepASR/tools/profile.py @@ -137,7 +137,12 @@ def profile(args): class_num=args.class_num, parallel=args.parallel) - optimizer = fluid.optimizer.Adam(learning_rate=args.learning_rate) + optimizer = fluid.optimizer.Adam( + learning_rate=fluid.layers.exponential_decay( + learning_rate=args.learning_rate, + decay_steps=1879, + decay_rate=1 / 1.2, + staircase=True)) optimizer.minimize(avg_cost) place = fluid.CPUPlace() if args.device == 'CPU' else fluid.CUDAPlace(0) @@ -150,7 +155,8 @@ def profile(args): trans_splice.TransSplice(5, 5), trans_delay.TransDelay(5) ] - data_reader = 
reader.AsyncDataReader(args.feature_lst, args.label_lst, -1) + data_reader = reader.AsyncDataReader( + args.feature_lst, args.label_lst, -1, split_sentence_threshold=1024) data_reader.set_transformers(ltrans) feature_t = fluid.LoDTensor() diff --git a/fluid/DeepASR/train.py b/fluid/DeepASR/train.py index 6073db0d07a436f40ac78e38ef072dd23b9dbad5..1c35a6637f534abf4a37763fe1915c35e18e1f94 100644 --- a/fluid/DeepASR/train.py +++ b/fluid/DeepASR/train.py @@ -159,7 +159,12 @@ def train(args): test_program = fluid.default_main_program().clone() #optimizer = fluid.optimizer.Momentum(learning_rate=args.learning_rate, momentum=0.9) - optimizer = fluid.optimizer.Adam(learning_rate=args.learning_rate) + optimizer = fluid.optimizer.Adam( + learning_rate=fluid.layers.exponential_decay( + learning_rate=args.learning_rate, + decay_steps=1879, + decay_rate=1 / 1.2, + staircase=True)) optimizer.minimize(avg_cost) place = fluid.CPUPlace() if args.device == 'CPU' else fluid.CUDAPlace(0) @@ -186,8 +191,11 @@ def train(args): os.path.exists(args.val_label_lst)): return -1.0, -1.0 # test data reader - test_data_reader = reader.AsyncDataReader(args.val_feature_lst, - args.val_label_lst, -1) + test_data_reader = reader.AsyncDataReader( + args.val_feature_lst, + args.val_label_lst, + -1, + split_sentence_threshold=1024) test_data_reader.set_transformers(ltrans) test_costs, test_accs = [], [] for batch_id, batch_data in enumerate( @@ -212,8 +220,11 @@ def train(args): return np.mean(test_costs), np.mean(test_accs) # train data reader - train_data_reader = reader.AsyncDataReader(args.train_feature_lst, - args.train_label_lst, -1) + train_data_reader = reader.AsyncDataReader( + args.train_feature_lst, + args.train_label_lst, + -1, + split_sentence_threshold=1024) train_data_reader.set_transformers(ltrans) # train diff --git a/fluid/face_detection/.gitignore b/fluid/face_detection/.gitignore index 13d42af893162c1908a39fea1d072a22929e5430..eeee7d7057bcb73e738e6df94c702a9e8c5dced6 100644 --- a/fluid/face_detection/.gitignore +++ b/fluid/face_detection/.gitignore @@ -4,4 +4,6 @@ data/ label/ *.swp *.log -infer_results/ +log* +output* +infer_results* diff --git a/fluid/face_detection/data_util.py b/fluid/face_detection/data_util.py new file mode 100644 index 0000000000000000000000000000000000000000..ac022593119e0008c3f7f3858303cbf5bc717650 --- /dev/null +++ b/fluid/face_detection/data_util.py @@ -0,0 +1,151 @@ +""" +This code is based on https://github.com/fchollet/keras/blob/master/keras/utils/data_utils.py +""" + +import time +import numpy as np +import threading +import multiprocessing +try: + import queue +except ImportError: + import Queue as queue + + +class GeneratorEnqueuer(object): + """ + Builds a queue out of a data generator. + + Args: + generator: a generator function which endlessly yields data + use_multiprocessing (bool): use multiprocessing if True, + otherwise use threading. + wait_time (float): time to sleep in-between calls to `put()`. + random_seed (int): Initial seed for workers, + will be incremented by one for each workers. + """ + + def __init__(self, + generator, + use_multiprocessing=False, + wait_time=0.05, + random_seed=None): + self.wait_time = wait_time + self._generator = generator + self._use_multiprocessing = use_multiprocessing + self._threads = [] + self._stop_event = None + self.queue = None + self._manager = None + self.seed = random_seed + + def start(self, workers=1, max_queue_size=10): + """ + Start worker threads which add data from the generator into the queue. 
+ + Args: + workers (int): number of worker threads + max_queue_size (int): queue size + (when full, threads could block on `put()`) + """ + + def data_generator_task(): + """ + Data generator task. + """ + + def task(): + if (self.queue is not None and + self.queue.qsize() < max_queue_size): + generator_output = next(self._generator) + self.queue.put((generator_output)) + else: + time.sleep(self.wait_time) + + if not self._use_multiprocessing: + while not self._stop_event.is_set(): + with self.genlock: + try: + task() + except Exception: + self._stop_event.set() + break + else: + while not self._stop_event.is_set(): + try: + task() + except Exception: + self._stop_event.set() + break + + try: + if self._use_multiprocessing: + self._manager = multiprocessing.Manager() + self.queue = self._manager.Queue(maxsize=max_queue_size) + self._stop_event = multiprocessing.Event() + else: + self.genlock = threading.Lock() + self.queue = queue.Queue() + self._stop_event = threading.Event() + for _ in range(workers): + if self._use_multiprocessing: + # Reset random seed else all children processes + # share the same seed + np.random.seed(self.seed) + thread = multiprocessing.Process(target=data_generator_task) + thread.daemon = True + if self.seed is not None: + self.seed += 1 + else: + thread = threading.Thread(target=data_generator_task) + self._threads.append(thread) + thread.start() + except: + self.stop() + raise + + def is_running(self): + """ + Returns: + bool: Whether the worker theads are running. + """ + return self._stop_event is not None and not self._stop_event.is_set() + + def stop(self, timeout=None): + """ + Stops running threads and wait for them to exit, if necessary. + Should be called by the same thread which called `start()`. + + Args: + timeout(int|None): maximum time to wait on `thread.join()`. + """ + if self.is_running(): + self._stop_event.set() + for thread in self._threads: + if self._use_multiprocessing: + if thread.is_alive(): + thread.terminate() + else: + thread.join(timeout) + if self._manager: + self._manager.shutdown() + + self._threads = [] + self._stop_event = None + self.queue = None + + def get(self): + """ + Creates a generator to extract data from the queue. + Skip the data if it is `None`. + + # Yields + tuple of data in the queue. 
+ """ + while self.is_running(): + if not self.queue.empty(): + inputs = self.queue.get() + if inputs is not None: + yield inputs + else: + time.sleep(self.wait_time) diff --git a/fluid/face_detection/image_util.py b/fluid/face_detection/image_util.py index 0d583396cb99439676c0bb44c4fc0ef9643de318..f39538285637c1a284c4058130be40d89435dcef 100644 --- a/fluid/face_detection/image_util.py +++ b/fluid/face_detection/image_util.py @@ -3,6 +3,7 @@ from PIL import ImageFile import numpy as np import random import math +import cv2 ImageFile.LOAD_TRUNCATED_IMAGES = True #otherwise IOError raised image file is truncated @@ -100,6 +101,76 @@ def generate_sample(sampler, image_width, image_height): return sampled_bbox +def data_anchor_sampling(sampler, bbox_labels, image_width, image_height, + scale_array, resize_width, resize_height): + num_gt = len(bbox_labels) + # np.random.randint range: [low, high) + rand_idx = np.random.randint(0, num_gt) if num_gt != 0 else 0 + + if num_gt != 0: + norm_xmin = bbox_labels[rand_idx][1] + norm_ymin = bbox_labels[rand_idx][2] + norm_xmax = bbox_labels[rand_idx][3] + norm_ymax = bbox_labels[rand_idx][4] + + xmin = norm_xmin * image_width + ymin = norm_ymin * image_height + wid = image_width * (norm_xmax - norm_xmin) + hei = image_height * (norm_ymax - norm_ymin) + range_size = 0 + + for scale_ind in range(0, len(scale_array) - 1): + area = wid * hei + if area > scale_array[scale_ind] ** 2 and area < \ + scale_array[scale_ind + 1] ** 2: + range_size = scale_ind + 1 + break + + scale_choose = 0.0 + if range_size == 0: + rand_idx_size = range_size + 1 + else: + # np.random.randint range: [low, high) + rng_rand_size = np.random.randint(0, range_size) + rand_idx_size = rng_rand_size % range_size + + scale_choose = random.uniform(scale_array[rand_idx_size] / 2.0, + 2.0 * scale_array[rand_idx_size]) + + sample_bbox_size = wid * resize_width / scale_choose + + w_off_orig = 0.0 + h_off_orig = 0.0 + if sample_bbox_size < max(image_height, image_width): + if wid <= sample_bbox_size: + w_off_orig = random.uniform(xmin + wid - sample_bbox_size, xmin) + else: + w_off_orig = random.uniform(xmin, xmin + wid - sample_bbox_size) + + if hei <= sample_bbox_size: + h_off_orig = random.uniform(ymin + hei - sample_bbox_size, ymin) + else: + h_off_orig = random.uniform(ymin, ymin + hei - sample_bbox_size) + + else: + w_off_orig = random.uniform(image_width - sample_bbox_size, 0.0) + h_off_orig = random.uniform(image_height - sample_bbox_size, 0.0) + + w_off_orig = math.floor(w_off_orig) + h_off_orig = math.floor(h_off_orig) + + # Figure out top left coordinates. 
+ w_off = 0.0 + h_off = 0.0 + w_off = float(w_off_orig / image_width) + h_off = float(h_off_orig / image_height) + + sampled_bbox = bbox(w_off, h_off, + w_off + float(sample_bbox_size / image_width), + h_off + float(sample_bbox_size / image_height)) + return sampled_bbox + + def jaccard_overlap(sample_bbox, object_bbox): if sample_bbox.xmin >= object_bbox.xmax or \ sample_bbox.xmax <= object_bbox.xmin or \ @@ -161,8 +232,6 @@ def satisfy_sample_constraint(sampler, sample_bbox, bbox_labels): def generate_batch_samples(batch_sampler, bbox_labels, image_width, image_height): sampled_bbox = [] - index = [] - c = 0 for sampler in batch_sampler: found = 0 for i in range(sampler.max_trial): @@ -172,8 +241,24 @@ def generate_batch_samples(batch_sampler, bbox_labels, image_width, if satisfy_sample_constraint(sampler, sample_bbox, bbox_labels): sampled_bbox.append(sample_bbox) found = found + 1 - index.append(c) - c = c + 1 + return sampled_bbox + + +def generate_batch_random_samples(batch_sampler, bbox_labels, image_width, + image_height, scale_array, resize_width, + resize_height): + sampled_bbox = [] + for sampler in batch_sampler: + found = 0 + for i in range(sampler.max_trial): + if found >= sampler.max_sample: + break + sample_bbox = data_anchor_sampling( + sampler, bbox_labels, image_width, image_height, scale_array, + resize_width, resize_height) + if satisfy_sample_constraint(sampler, sample_bbox, bbox_labels): + sampled_bbox.append(sample_bbox) + found = found + 1 return sampled_bbox @@ -237,48 +322,117 @@ def transform_labels(bbox_labels, sample_bbox): return sample_labels -def crop_image(img, bbox_labels, sample_bbox, image_width, image_height): +def transform_labels_sampling(bbox_labels, sample_bbox, resize_val, + min_face_size): + sample_labels = [] + for i in range(len(bbox_labels)): + sample_label = [] + object_bbox = bbox(bbox_labels[i][1], bbox_labels[i][2], + bbox_labels[i][3], bbox_labels[i][4]) + if not meet_emit_constraint(object_bbox, sample_bbox): + continue + proj_bbox = project_bbox(object_bbox, sample_bbox) + if proj_bbox: + real_width = float((proj_bbox.xmax - proj_bbox.xmin) * resize_val) + real_height = float((proj_bbox.ymax - proj_bbox.ymin) * resize_val) + if real_width * real_height < float(min_face_size * min_face_size): + continue + else: + sample_label.append(bbox_labels[i][0]) + sample_label.append(float(proj_bbox.xmin)) + sample_label.append(float(proj_bbox.ymin)) + sample_label.append(float(proj_bbox.xmax)) + sample_label.append(float(proj_bbox.ymax)) + sample_label = sample_label + bbox_labels[i][5:] + sample_labels.append(sample_label) + return sample_labels + + +def crop_image(img, bbox_labels, sample_bbox, image_width, image_height, + resize_width, resize_height, min_face_size): sample_bbox = clip_bbox(sample_bbox) xmin = int(sample_bbox.xmin * image_width) xmax = int(sample_bbox.xmax * image_width) ymin = int(sample_bbox.ymin * image_height) ymax = int(sample_bbox.ymax * image_height) + sample_img = img[ymin:ymax, xmin:xmax] - sample_labels = transform_labels(bbox_labels, sample_bbox) + resize_val = resize_width + sample_labels = transform_labels_sampling(bbox_labels, sample_bbox, + resize_val, min_face_size) + return sample_img, sample_labels + + +def crop_image_sampling(img, bbox_labels, sample_bbox, image_width, + image_height, resize_width, resize_height, + min_face_size): + # no clipping here + xmin = int(sample_bbox.xmin * image_width) + xmax = int(sample_bbox.xmax * image_width) + ymin = int(sample_bbox.ymin * image_height) + ymax = 
int(sample_bbox.ymax * image_height) + + w_off = xmin + h_off = ymin + width = xmax - xmin + height = ymax - ymin + + cross_xmin = max(0.0, float(w_off)) + cross_ymin = max(0.0, float(h_off)) + cross_xmax = min(float(w_off + width - 1.0), float(image_width)) + cross_ymax = min(float(h_off + height - 1.0), float(image_height)) + cross_width = cross_xmax - cross_xmin + cross_height = cross_ymax - cross_ymin + + roi_xmin = 0 if w_off >= 0 else abs(w_off) + roi_ymin = 0 if h_off >= 0 else abs(h_off) + roi_width = cross_width + roi_height = cross_height + + sample_img = np.zeros((height, width, 3)) + sample_img[int(roi_ymin) : int(roi_ymin + roi_height), int(roi_xmin) : int(roi_xmin + roi_width)] = \ + img[int(cross_ymin) : int(cross_ymin + cross_height), int(cross_xmin) : int(cross_xmin + cross_width)] + + sample_img = cv2.resize( + sample_img, (resize_width, resize_height), interpolation=cv2.INTER_AREA) + + resize_val = resize_width + sample_labels = transform_labels_sampling(bbox_labels, sample_bbox, + resize_val, min_face_size) return sample_img, sample_labels def random_brightness(img, settings): prob = random.uniform(0, 1) - if prob < settings._brightness_prob: - delta = random.uniform(-settings._brightness_delta, - settings._brightness_delta) + 1 + if prob < settings.brightness_prob: + delta = random.uniform(-settings.brightness_delta, + settings.brightness_delta) + 1 img = ImageEnhance.Brightness(img).enhance(delta) return img def random_contrast(img, settings): prob = random.uniform(0, 1) - if prob < settings._contrast_prob: - delta = random.uniform(-settings._contrast_delta, - settings._contrast_delta) + 1 + if prob < settings.contrast_prob: + delta = random.uniform(-settings.contrast_delta, + settings.contrast_delta) + 1 img = ImageEnhance.Contrast(img).enhance(delta) return img def random_saturation(img, settings): prob = random.uniform(0, 1) - if prob < settings._saturation_prob: - delta = random.uniform(-settings._saturation_delta, - settings._saturation_delta) + 1 + if prob < settings.saturation_prob: + delta = random.uniform(-settings.saturation_delta, + settings.saturation_delta) + 1 img = ImageEnhance.Color(img).enhance(delta) return img def random_hue(img, settings): prob = random.uniform(0, 1) - if prob < settings._hue_prob: - delta = random.uniform(-settings._hue_delta, settings._hue_delta) + if prob < settings.hue_prob: + delta = random.uniform(-settings.hue_delta, settings.hue_delta) img_hsv = np.array(img.convert('HSV')) img_hsv[:, :, 0] = img_hsv[:, :, 0] + delta img = Image.fromarray(img_hsv, mode='HSV').convert('RGB') @@ -303,9 +457,9 @@ def distort_image(img, settings): def expand_image(img, bbox_labels, img_width, img_height, settings): prob = random.uniform(0, 1) - if prob < settings._expand_prob: - if settings._expand_max_ratio - 1 >= 0.01: - expand_ratio = random.uniform(1, settings._expand_max_ratio) + if prob < settings.expand_prob: + if settings.expand_max_ratio - 1 >= 0.01: + expand_ratio = random.uniform(1, settings.expand_max_ratio) height = int(img_height * expand_ratio) width = int(img_width * expand_ratio) h_off = math.floor(random.uniform(0, height - img_height)) @@ -314,7 +468,7 @@ def expand_image(img, bbox_labels, img_width, img_height, settings): (width - w_off) / img_width, (height - h_off) / img_height) expand_img = np.ones((height, width, 3)) - expand_img = np.uint8(expand_img * np.squeeze(settings._img_mean)) + expand_img = np.uint8(expand_img * np.squeeze(settings.img_mean)) expand_img = Image.fromarray(expand_img) expand_img.paste(img, 
(int(w_off), int(h_off))) bbox_labels = transform_labels(bbox_labels, expand_bbox) diff --git a/fluid/face_detection/infer.py b/fluid/face_detection/infer.py index 71a878cb39f9888e3c308ee24e34dd6c3a073d33..a9468c33c110e04c82c9845414e1d83fee0bb7a7 100644 --- a/fluid/face_detection/infer.py +++ b/fluid/face_detection/infer.py @@ -15,7 +15,7 @@ parser = argparse.ArgumentParser(description=__doc__) add_arg = functools.partial(add_arguments, argparser=parser) # yapf: disable add_arg('use_gpu', bool, True, "Whether use GPU.") -add_arg('use_pyramidbox', bool, False, "Whether use PyramidBox model.") +add_arg('use_pyramidbox', bool, True, "Whether use PyramidBox model.") add_arg('confs_threshold', float, 0.25, "Confidence threshold to draw bbox.") add_arg('image_path', str, '', "The data root path.") add_arg('model_dir', str, '', "The model path.") @@ -168,6 +168,9 @@ def detect_face(image, shrink): return_numpy=False) detection = np.array(detection) # layout: xmin, ymin, xmax. ymax, score + if detection.shape == (1, ): + print("No face detected") + return np.array([[0, 0, 0, 0, 0]]) det_conf = detection[:, 1] det_xmin = image_shape[2] * detection[:, 2] / shrink det_ymin = image_shape[1] * detection[:, 3] / shrink @@ -227,6 +230,33 @@ def multi_scale_test(image, max_shrink): return det_s, det_b +def multi_scale_test_pyramid(image, max_shrink): + # shrink detecting and shrink only detect big face + det_b = detect_face(image, 0.25) + index = np.where( + np.maximum(det_b[:, 2] - det_b[:, 0] + 1, det_b[:, 3] - det_b[:, 1] + 1) + > 30)[0] + det_b = det_b[index, :] + + st = [0.5, 0.75, 1.25, 1.5, 1.75, 2.25] + for i in range(len(st)): + if (st[i] <= max_shrink): + det_temp = detect_face(image, st[i]) + # enlarge only detect small face + if st[i] > 1: + index = np.where( + np.minimum(det_temp[:, 2] - det_temp[:, 0] + 1, + det_temp[:, 3] - det_temp[:, 1] + 1) < 100)[0] + det_temp = det_temp[index, :] + else: + index = np.where( + np.maximum(det_temp[:, 2] - det_temp[:, 0] + 1, + det_temp[:, 3] - det_temp[:, 1] + 1) > 30)[0] + det_temp = det_temp[index, :] + det_b = np.row_stack((det_b, det_temp)) + return det_b + + def get_im_shrink(image_shape): max_shrink_v1 = (0x7fffffff / 577.0 / (image_shape[1] * image_shape[2]))**0.5 @@ -272,7 +302,8 @@ def infer(args, batch_size, data_args): det0 = detect_face(image, shrink) det1 = flip_test(image, shrink) [det2, det3] = multi_scale_test(image, max_shrink) - det = np.row_stack((det0, det1, det2, det3)) + det4 = multi_scale_test_pyramid(image, max_shrink) + det = np.row_stack((det0, det1, det2, det3, det4)) dets = bbox_vote(det) image_name = image_path.split('/')[-1] diff --git a/fluid/face_detection/profile.py b/fluid/face_detection/profile.py new file mode 100644 index 0000000000000000000000000000000000000000..fd686ad0784abd730d41263e3982560345ca6908 --- /dev/null +++ b/fluid/face_detection/profile.py @@ -0,0 +1,190 @@ +import os +import shutil +import numpy as np +import time +import argparse +import functools + +import reader +import paddle +import paddle.fluid as fluid +import paddle.fluid.profiler as profiler +from pyramidbox import PyramidBox +from utility import add_arguments, print_arguments + +parser = argparse.ArgumentParser(description=__doc__) +add_arg = functools.partial(add_arguments, argparser=parser) + +# yapf: disable +add_arg('parallel', bool, True, "parallel") +add_arg('learning_rate', float, 0.001, "Learning rate.") +add_arg('batch_size', int, 20, "Minibatch size.") +add_arg('num_iteration', int, 10, "Epoch number.") +add_arg('skip_reader', bool, 
False, "Whether to skip data reader.") +add_arg('use_gpu', bool, True, "Whether use GPU.") +add_arg('use_pyramidbox', bool, True, "Whether use PyramidBox model.") +add_arg('model_save_dir', str, 'output', "The path to save model.") +add_arg('pretrained_model', str, './pretrained/', "The init model path.") +add_arg('resize_h', int, 640, "The resized image height.") +add_arg('resize_w', int, 640, "The resized image height.") +#yapf: enable + + +def train(args, config, train_file_list, optimizer_method): + learning_rate = args.learning_rate + batch_size = args.batch_size + height = args.resize_h + width = args.resize_w + use_gpu = args.use_gpu + use_pyramidbox = args.use_pyramidbox + model_save_dir = args.model_save_dir + pretrained_model = args.pretrained_model + skip_reader = args.skip_reader + num_iterations = args.num_iteration + parallel = args.parallel + + num_classes = 2 + image_shape = [3, height, width] + + devices = os.getenv("CUDA_VISIBLE_DEVICES") or "" + devices_num = len(devices.split(",")) + + fetches = [] + network = PyramidBox(image_shape, num_classes, + sub_network=use_pyramidbox) + if use_pyramidbox: + face_loss, head_loss, loss = network.train() + fetches = [face_loss, head_loss] + else: + loss = network.vgg_ssd_loss() + fetches = [loss] + + epocs = 12880 / batch_size + boundaries = [epocs * 40, epocs * 60, epocs * 80, epocs * 100] + values = [ + learning_rate, learning_rate * 0.5, learning_rate * 0.25, + learning_rate * 0.1, learning_rate * 0.01 + ] + + if optimizer_method == "momentum": + optimizer = fluid.optimizer.Momentum( + learning_rate=fluid.layers.piecewise_decay( + boundaries=boundaries, values=values), + momentum=0.9, + regularization=fluid.regularizer.L2Decay(0.0005), + ) + else: + optimizer = fluid.optimizer.RMSProp( + learning_rate=fluid.layers.piecewise_decay(boundaries, values), + regularization=fluid.regularizer.L2Decay(0.0005), + ) + + optimizer.minimize(loss) + fluid.memory_optimize(fluid.default_main_program()) + + place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + start_pass = 0 + if pretrained_model: + if pretrained_model.isdigit(): + start_pass = int(pretrained_model) + 1 + pretrained_model = os.path.join(model_save_dir, pretrained_model) + print("Resume from %s " %(pretrained_model)) + + if not os.path.exists(pretrained_model): + raise ValueError("The pre-trained model path [%s] does not exist." 
% + (pretrained_model)) + def if_exist(var): + return os.path.exists(os.path.join(pretrained_model, var.name)) + fluid.io.load_vars(exe, pretrained_model, predicate=if_exist) + + if parallel: + train_exe = fluid.ParallelExecutor( + use_cuda=use_gpu, loss_name=loss.name) + + train_reader = reader.train_batch_reader(config, train_file_list, batch_size=batch_size) + + def tensor(data, place, lod=None): + t = fluid.core.LoDTensor() + t.set(data, place) + if lod: + t.set_lod(lod) + return t + + im, face_box, head_box, labels, lod = next(train_reader) + im_t = tensor(im, place) + box1 = tensor(face_box, place, [lod]) + box2 = tensor(head_box, place, [lod]) + lbl_t = tensor(labels, place, [lod]) + feed_data = {'image': im_t, 'face_box': box1, + 'head_box': box2, 'gt_label': lbl_t} + + def run(iterations, feed_data): + # global feed_data + reader_time = [] + run_time = [] + for batch_id in range(iterations): + start_time = time.time() + if not skip_reader: + im, face_box, head_box, labels, lod = next(train_reader) + im_t = tensor(im, place) + box1 = tensor(face_box, place, [lod]) + box2 = tensor(head_box, place, [lod]) + lbl_t = tensor(labels, place, [lod]) + feed_data = {'image': im_t, 'face_box': box1, + 'head_box': box2, 'gt_label': lbl_t} + end_time = time.time() + reader_time.append(end_time - start_time) + + start_time = time.time() + if parallel: + fetch_vars = train_exe.run(fetch_list=[v.name for v in fetches], + feed=feed_data) + else: + fetch_vars = exe.run(fluid.default_main_program(), + feed=feed_data, + fetch_list=fetches) + end_time = time.time() + run_time.append(end_time - start_time) + fetch_vars = [np.mean(np.array(v)) for v in fetch_vars] + if not args.use_pyramidbox: + print("Batch {0}, loss {1}".format(batch_id, fetch_vars[0])) + else: + print("Batch {0}, face loss {1}, head loss {2}".format( + batch_id, fetch_vars[0], fetch_vars[1])) + + return reader_time, run_time + + # start-up + run(2, feed_data) + + # profiling + start = time.time() + if not parallel: + with profiler.profiler('All', 'total', '/tmp/profile_file'): + reader_time, run_time = run(num_iterations, feed_data) + else: + reader_time, run_time = run(num_iterations, feed_data) + end = time.time() + total_time = end - start + print("Total time: {0}, reader time: {1} s, run time: {2} s".format( + total_time, np.sum(reader_time), np.sum(run_time))) + + +if __name__ == '__main__': + args = parser.parse_args() + print_arguments(args) + + data_dir = 'data/WIDERFACE/WIDER_train/images/' + train_file_list = 'label/train_gt_widerface.res' + + config = reader.Settings( + data_dir=data_dir, + resize_h=args.resize_h, + resize_w=args.resize_w, + apply_expand=False, + mean_value=[104., 117., 123.], + ap_version='11point') + train(args, config, train_file_list, optimizer_method="momentum") diff --git a/fluid/face_detection/pyramidbox.py b/fluid/face_detection/pyramidbox.py index ce01cb7a113219e08d4deb2984d2a12b2590faa5..be28827ff4edfd5ddbc1e78cbef268a5629400cd 100644 --- a/fluid/face_detection/pyramidbox.py +++ b/fluid/face_detection/pyramidbox.py @@ -81,10 +81,7 @@ class PyramidBox(object): if self.is_infer: return [self.image] else: - return [ - self.image, self.face_box, self.head_box, self.gt_label, - self.difficult - ] + return [self.image, self.face_box, self.head_box, self.gt_label] def _input(self): self.image = fluid.layers.data( @@ -96,8 +93,6 @@ class PyramidBox(object): name='head_box', shape=[4], dtype='float32', lod_level=1) self.gt_label = fluid.layers.data( name='gt_label', shape=[1], dtype='int32', lod_level=1) - 
self.difficult = fluid.layers.data( - name='gt_difficult', shape=[1], dtype='int32', lod_level=1) def _vgg(self): self.conv1, self.pool1 = conv_block(self.image, 2, [64] * 2, [3] * 2) @@ -144,7 +139,8 @@ class PyramidBox(object): stride=2, groups=ch, param_attr=w_attr, - bias_attr=False) + bias_attr=False, + use_cudnn=True) else: upsampling = fluid.layers.resize_bilinear( conv1, out_shape=up_to.shape[2:]) @@ -385,6 +381,7 @@ class PyramidBox(object): self.box_vars, overlap_threshold=0.35, neg_overlap=0.35) + face_loss.persistable = True head_loss = fluid.layers.ssd_loss( self.head_mbox_loc, self.head_mbox_conf, @@ -394,9 +391,13 @@ class PyramidBox(object): self.box_vars, overlap_threshold=0.35, neg_overlap=0.35) + head_loss.persistable = True face_loss = fluid.layers.reduce_sum(face_loss) + face_loss.persistable = True head_loss = fluid.layers.reduce_sum(head_loss) + head_loss.persistable = True total_loss = face_loss + head_loss + total_loss.persistable = True return face_loss, head_loss, total_loss def infer(self, main_program=None): @@ -410,5 +411,8 @@ class PyramidBox(object): self.face_mbox_conf, self.prior_boxes, self.box_vars, - nms_threshold=0.45) + nms_threshold=0.3, + nms_top_k=5000, + keep_top_k=750, + score_threshold=0.05) return test_program, face_nmsed_out diff --git a/fluid/face_detection/reader.py b/fluid/face_detection/reader.py index 42109b1194cad071c6571ffa1eb590526a688033..836d78c2d093fa2cde4a4471495a0c1dcb9b94f3 100644 --- a/fluid/face_detection/reader.py +++ b/fluid/face_detection/reader.py @@ -22,6 +22,9 @@ import xml.etree.ElementTree import os import time import copy +import random +import cv2 +from data_util import GeneratorEnqueuer class Settings(object): @@ -36,112 +39,130 @@ class Settings(object): apply_expand=True, ap_version='11point', toy=0): - self._dataset = dataset - self._ap_version = ap_version - self._toy = toy - self._data_dir = data_dir - self._apply_distort = apply_distort - self._apply_expand = apply_expand - self._resize_height = resize_h - self._resize_width = resize_w - self._img_mean = np.array(mean_value)[:, np.newaxis, np.newaxis].astype( + self.dataset = dataset + self.ap_version = ap_version + self.toy = toy + self.data_dir = data_dir + self.apply_distort = apply_distort + self.apply_expand = apply_expand + self.resize_height = resize_h + self.resize_width = resize_w + self.img_mean = np.array(mean_value)[:, np.newaxis, np.newaxis].astype( 'float32') - self._expand_prob = 0.5 - self._expand_max_ratio = 4 - self._hue_prob = 0.5 - self._hue_delta = 18 - self._contrast_prob = 0.5 - self._contrast_delta = 0.5 - self._saturation_prob = 0.5 - self._saturation_delta = 0.5 - self._brightness_prob = 0.5 + self.expand_prob = 0.5 + self.expand_max_ratio = 4 + self.hue_prob = 0.5 + self.hue_delta = 18 + self.contrast_prob = 0.5 + self.contrast_delta = 0.5 + self.saturation_prob = 0.5 + self.saturation_delta = 0.5 + self.brightness_prob = 0.5 # _brightness_delta is the normalized value by 256 # self._brightness_delta = 32 - self._brightness_delta = 0.125 - - @property - def dataset(self): - return self._dataset - - @property - def ap_version(self): - return self._ap_version + self.brightness_delta = 0.125 + self.scale = 0.007843 # 1 / 127.5 + self.data_anchor_sampling_prob = 0.5 + self.min_face_size = 8.0 - @property - def toy(self): - return self._toy - @property - def apply_expand(self): - return self._apply_expand +def draw_image(faces_pred, img, resize_val): + for i in range(len(faces_pred)): + draw_rotate_rectange(img, faces_pred[i], resize_val, 
(0, 255, 0), 3) - @property - def apply_distort(self): - return self._apply_distort - @property - def data_dir(self): - return self._data_dir +def draw_rotate_rectange(img, face, resize_val, color, thickness): + cv2.line(img, (int(face[1] * resize_val), int(face[2] * resize_val)), (int( + face[3] * resize_val), int(face[2] * resize_val)), color, thickness) - @data_dir.setter - def data_dir(self, data_dir): - self._data_dir = data_dir + cv2.line(img, (int(face[3] * resize_val), int(face[2] * resize_val)), (int( + face[3] * resize_val), int(face[4] * resize_val)), color, thickness) - @property - def label_list(self): - return self._label_list + cv2.line(img, (int(face[1] * resize_val), int(face[2] * resize_val)), (int( + face[1] * resize_val), int(face[4] * resize_val)), color, thickness) - @property - def resize_h(self): - return self._resize_height + cv2.line(img, (int(face[3] * resize_val), int(face[4] * resize_val)), (int( + face[1] * resize_val), int(face[4] * resize_val)), color, thickness) - @property - def resize_w(self): - return self._resize_width - @property - def img_mean(self): - return self._img_mean - - -def preprocess(img, bbox_labels, mode, settings): +def preprocess(img, bbox_labels, mode, settings, image_path): img_width, img_height = img.size sampled_labels = bbox_labels if mode == 'train': - if settings._apply_distort: + if settings.apply_distort: img = image_util.distort_image(img, settings) - if settings._apply_expand: + if settings.apply_expand: img, bbox_labels, img_width, img_height = image_util.expand_image( img, bbox_labels, img_width, img_height, settings) + # sampling batch_sampler = [] - # hard-code here - batch_sampler.append( - image_util.sampler(1, 50, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, - True)) - batch_sampler.append( - image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, - True)) - batch_sampler.append( - image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, - True)) - batch_sampler.append( - image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, - True)) - batch_sampler.append( - image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, - True)) - sampled_bbox = image_util.generate_batch_samples( - batch_sampler, bbox_labels, img_width, img_height) - img = np.array(img) - if len(sampled_bbox) > 0: - idx = int(random.uniform(0, len(sampled_bbox))) - img, sampled_labels = image_util.crop_image( - img, bbox_labels, sampled_bbox[idx], img_width, img_height) - - img = Image.fromarray(img) - img = img.resize((settings.resize_w, settings.resize_h), Image.ANTIALIAS) + prob = random.uniform(0., 1.) 
+ if prob > settings.data_anchor_sampling_prob: + scale_array = np.array([16, 32, 64, 128, 256, 512]) + batch_sampler.append( + image_util.sampler(1, 10, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.2, + 0.0, True)) + sampled_bbox = image_util.generate_batch_random_samples( + batch_sampler, bbox_labels, img_width, img_height, scale_array, + settings.resize_width, settings.resize_height) + img = np.array(img) + # Debug + # img_save = Image.fromarray(img) + # img_save.save('img_orig.jpg') + if len(sampled_bbox) > 0: + idx = int(random.uniform(0, len(sampled_bbox))) + img, sampled_labels = image_util.crop_image_sampling( + img, bbox_labels, sampled_bbox[idx], img_width, img_height, + settings.resize_width, settings.resize_height, + settings.min_face_size) + + img = img.astype('uint8') + # Debug: visualize the gt bbox + visualize_bbox = 0 + if visualize_bbox: + img_show = img + draw_image(sampled_labels, img_show, settings.resize_height) + img_show = Image.fromarray(img_show) + img_show.save('final_img_show.jpg') + + img = Image.fromarray(img) + # Debug + # img.save('final_img.jpg') + + else: + # hard-code here + batch_sampler.append( + image_util.sampler(1, 50, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, + 0.0, True)) + batch_sampler.append( + image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, + 0.0, True)) + batch_sampler.append( + image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, + 0.0, True)) + batch_sampler.append( + image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, + 0.0, True)) + batch_sampler.append( + image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, + 0.0, True)) + sampled_bbox = image_util.generate_batch_samples( + batch_sampler, bbox_labels, img_width, img_height) + + img = np.array(img) + if len(sampled_bbox) > 0: + idx = int(random.uniform(0, len(sampled_bbox))) + img, sampled_labels = image_util.crop_image( + img, bbox_labels, sampled_bbox[idx], img_width, img_height, + settings.resize_width, settings.resize_height, + settings.min_face_size) + + img = Image.fromarray(img) + + img = img.resize((settings.resize_width, settings.resize_height), + Image.ANTIALIAS) img = np.array(img) if mode == 'train': @@ -160,27 +181,26 @@ def preprocess(img, bbox_labels, mode, settings): img = img[[2, 1, 0], :, :] img = img.astype('float32') img -= settings.img_mean - img = img * 0.007843 + img = img * settings.scale return img, sampled_labels -def put_txt_in_dict(input_txt): +def load_file_list(input_txt): with open(input_txt, 'r') as f_dir: lines_input_txt = f_dir.readlines() - dict_input_txt = {} + file_dict = {} num_class = 0 for i in range(len(lines_input_txt)): tmp_line_txt = lines_input_txt[i].strip('\n\t\r') if '--' in tmp_line_txt: if i != 0: num_class += 1 - dict_input_txt[num_class] = [] + file_dict[num_class] = [] dict_name = tmp_line_txt - dict_input_txt[num_class].append(tmp_line_txt) + file_dict[num_class].append(tmp_line_txt) if '--' not in tmp_line_txt: if len(tmp_line_txt) > 6: - # tmp_line_txt = tmp_line_txt[:-2] split_str = tmp_line_txt.split(' ') x1_min = float(split_str[0]) y1_min = float(split_str[1]) @@ -188,11 +208,11 @@ def put_txt_in_dict(input_txt): y2_max = float(split_str[3]) tmp_line_txt = str(x1_min) + ' ' + str(y1_min) + ' ' + str( x2_max) + ' ' + str(y2_max) - dict_input_txt[num_class].append(tmp_line_txt) + file_dict[num_class].append(tmp_line_txt) else: - dict_input_txt[num_class].append(tmp_line_txt) + file_dict[num_class].append(tmp_line_txt) - return dict_input_txt + return file_dict def expand_bboxes(bboxes, @@ -219,67 +239,106 @@ def 
expand_bboxes(bboxes, return expand_boxes -def pyramidbox(settings, file_list, mode, shuffle): - - dict_input_txt = {} - dict_input_txt = put_txt_in_dict(file_list) +def train_generator(settings, file_list, batch_size, shuffle=True): + file_dict = load_file_list(file_list) + while True: + if shuffle: + random.shuffle(file_dict) + images, face_boxes, head_boxes, label_ids = [], [], [], [] + label_offs = [0] - def reader(): - if mode == 'train' and shuffle: - random.shuffle(dict_input_txt) - for index_image in range(len(dict_input_txt)): - - image_name = dict_input_txt[index_image][0] + '.jpg' + for index_image in file_dict.keys(): + image_name = file_dict[index_image][0] + '.jpg' image_path = os.path.join(settings.data_dir, image_name) - im = Image.open(image_path) if im.mode == 'L': im = im.convert('RGB') im_width, im_height = im.size # layout: label | xmin | ymin | xmax | ymax - if mode == 'train': - bbox_labels = [] - for index_box in range(len(dict_input_txt[index_image])): - if index_box >= 2: - bbox_sample = [] - temp_info_box = dict_input_txt[index_image][ - index_box].split(' ') - xmin = float(temp_info_box[0]) - ymin = float(temp_info_box[1]) - w = float(temp_info_box[2]) - h = float(temp_info_box[3]) - xmax = xmin + w - ymax = ymin + h - - bbox_sample.append(1) - bbox_sample.append(float(xmin) / im_width) - bbox_sample.append(float(ymin) / im_height) - bbox_sample.append(float(xmax) / im_width) - bbox_sample.append(float(ymax) / im_height) - bbox_labels.append(bbox_sample) - - im, sample_labels = preprocess(im, bbox_labels, mode, settings) - sample_labels = np.array(sample_labels) - if len(sample_labels) == 0: continue - im = im.astype('float32') - boxes = sample_labels[:, 1:5] - lbls = [1] * len(boxes) - difficults = [1] * len(boxes) - yield im, boxes, expand_bboxes(boxes), lbls, difficults - - if mode == 'test': - yield im, image_path - - return reader + bbox_labels = [] + for index_box in range(len(file_dict[index_image])): + if index_box >= 2: + bbox_sample = [] + temp_info_box = file_dict[index_image][index_box].split(' ') + xmin = float(temp_info_box[0]) + ymin = float(temp_info_box[1]) + w = float(temp_info_box[2]) + h = float(temp_info_box[3]) + xmax = xmin + w + ymax = ymin + h + + bbox_sample.append(1) + bbox_sample.append(float(xmin) / im_width) + bbox_sample.append(float(ymin) / im_height) + bbox_sample.append(float(xmax) / im_width) + bbox_sample.append(float(ymax) / im_height) + bbox_labels.append(bbox_sample) + + im, sample_labels = preprocess(im, bbox_labels, "train", settings, + image_path) + sample_labels = np.array(sample_labels) + if len(sample_labels) == 0: continue + + im = im.astype('float32') + face_box = sample_labels[:, 1:5] + head_box = expand_bboxes(face_box) + label = [1] * len(face_box) + + images.append(im) + face_boxes.extend(face_box) + head_boxes.extend(head_box) + label_ids.extend(label) + label_offs.append(label_offs[-1] + len(face_box)) + + if len(images) == batch_size: + images = np.array(images).astype('float32') + face_boxes = np.array(face_boxes).astype('float32') + head_boxes = np.array(head_boxes).astype('float32') + label_ids = np.array(label_ids).astype('int32') + yield images, face_boxes, head_boxes, label_ids, label_offs + images, face_boxes, head_boxes = [], [], [] + label_ids, label_offs = [], [0] + + +def train_batch_reader(settings, + file_list, + batch_size, + shuffle=True, + num_workers=8): + try: + enqueuer = GeneratorEnqueuer( + train_generator(settings, file_list, batch_size, shuffle), + use_multiprocessing=False) + 
enqueuer.start(max_queue_size=24, workers=num_workers) + generator_output = None + while True: + while enqueuer.is_running(): + if not enqueuer.queue.empty(): + generator_output = enqueuer.queue.get() + break + else: + time.sleep(0.01) + yield generator_output + generator_output = None + finally: + if enqueuer is not None: + enqueuer.stop() -def train(settings, file_list, shuffle=True): - return pyramidbox(settings, file_list, 'train', shuffle) +def test(settings, file_list): + file_dict = load_file_list(file_list) + def reader(): + for index_image in file_dict.keys(): + image_name = file_dict[index_image][0] + '.jpg' + image_path = os.path.join(settings.data_dir, image_name) + im = Image.open(image_path) + if im.mode == 'L': + im = im.convert('RGB') + yield im, image_path -def test(settings, file_list): - return pyramidbox(settings, file_list, 'test', False) + return reader def infer(settings, image_path): @@ -288,8 +347,8 @@ def infer(settings, image_path): if img.mode == 'L': img = im.convert('RGB') im_width, im_height = img.size - if settings.resize_w and settings.resize_h: - img = img.resize((settings.resize_w, settings.resize_h), + if settings.resize_width and settings.resize_height: + img = img.resize((settings.resize_width, settings.resize_height), Image.ANTIALIAS) img = np.array(img) # HWC to CHW @@ -300,9 +359,7 @@ def infer(settings, image_path): img = img[[2, 1, 0], :, :] img = img.astype('float32') img -= settings.img_mean - img = img * 0.007843 - img = [img] - img = np.array(img) - return img + img = img * settings.scale + return np.array([img]) return batch_reader diff --git a/fluid/face_detection/train.py b/fluid/face_detection/train.py index c10722b9e33d6c9d05f961d3b2cf73a859b9da3c..ecb9d76d1c81f03984756ef16348f5f1bef5943b 100644 --- a/fluid/face_detection/train.py +++ b/fluid/face_detection/train.py @@ -15,42 +15,52 @@ parser = argparse.ArgumentParser(description=__doc__) add_arg = functools.partial(add_arguments, argparser=parser) # yapf: disable -add_arg('parallel', bool, True, "parallel") -add_arg('learning_rate', float, 0.001, "Learning rate.") -add_arg('batch_size', int, 12, "Minibatch size.") -add_arg('num_passes', int, 120, "Epoch number.") -add_arg('use_gpu', bool, True, "Whether use GPU.") -add_arg('use_pyramidbox', bool, True, "Whether use PyramidBox model.") -add_arg('dataset', str, 'WIDERFACE', "coco2014, coco2017, and pascalvoc.") -add_arg('model_save_dir', str, 'model', "The path to save model.") -add_arg('pretrained_model', str, './pretrained/', "The init model path.") -add_arg('resize_h', int, 640, "The resized image height.") -add_arg('resize_w', int, 640, "The resized image height.") +add_arg('parallel', bool, True, "parallel") +add_arg('learning_rate', float, 0.001, "Learning rate.") +add_arg('batch_size', int, 12, "Minibatch size.") +add_arg('num_passes', int, 160, "Epoch number.") +add_arg('use_gpu', bool, True, "Whether use GPU.") +add_arg('use_pyramidbox', bool, True, "Whether use PyramidBox model.") +add_arg('model_save_dir', str, 'output', "The path to save model.") +add_arg('pretrained_model', str, './pretrained/', "The init model path.") +add_arg('resize_h', int, 640, "The resized image height.") +add_arg('resize_w', int, 640, "The resized image height.") +add_arg('with_mem_opt', bool, False, "Whether to use memory optimization or not.") #yapf: enable -def train(args, data_args, learning_rate, batch_size, pretrained_model, - num_passes, optimizer_method): +def train(args, config, train_file_list, optimizer_method): + learning_rate = 
args.learning_rate + batch_size = args.batch_size + num_passes = args.num_passes + height = args.resize_h + width = args.resize_w + use_gpu = args.use_gpu + use_pyramidbox = args.use_pyramidbox + model_save_dir = args.model_save_dir + pretrained_model = args.pretrained_model + with_memory_optimization = args.with_mem_opt num_classes = 2 + image_shape = [3, height, width] devices = os.getenv("CUDA_VISIBLE_DEVICES") or "" devices_num = len(devices.split(",")) - image_shape = [3, data_args.resize_h, data_args.resize_w] fetches = [] network = PyramidBox(image_shape, num_classes, - sub_network=args.use_pyramidbox) - if args.use_pyramidbox: + sub_network=use_pyramidbox) + if use_pyramidbox: face_loss, head_loss, loss = network.train() fetches = [face_loss, head_loss] else: loss = network.vgg_ssd_loss() fetches = [loss] - epocs = 12880 / batch_size - boundaries = [epocs * 40, epocs * 60, epocs * 80, epocs * 100] + steps_per_pass = 12880 / batch_size + boundaries = [steps_per_pass * 50, steps_per_pass * 80, + steps_per_pass * 120, steps_per_pass * 140] values = [ learning_rate, learning_rate * 0.5, learning_rate * 0.25, learning_rate * 0.1, learning_rate * 0.01 @@ -70,9 +80,10 @@ def train(args, data_args, learning_rate, batch_size, pretrained_model, ) optimizer.minimize(loss) - # fluid.memory_optimize(fluid.default_main_program()) + if with_memory_optimization: + fluid.memory_optimize(fluid.default_main_program()) - place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace() + place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace() exe = fluid.Executor(place) exe.run(fluid.default_startup_program()) @@ -80,7 +91,7 @@ def train(args, data_args, learning_rate, batch_size, pretrained_model, if pretrained_model: if pretrained_model.isdigit(): start_pass = int(pretrained_model) + 1 - pretrained_model = os.path.join(args.model_save_dir, pretrained_model) + pretrained_model = os.path.join(model_save_dir, pretrained_model) print("Resume from %s " %(pretrained_model)) if not os.path.exists(pretrained_model): @@ -92,11 +103,9 @@ def train(args, data_args, learning_rate, batch_size, pretrained_model, if args.parallel: train_exe = fluid.ParallelExecutor( - use_cuda=args.use_gpu, loss_name=loss.name) + use_cuda=use_gpu, loss_name=loss.name) - train_reader = paddle.batch( - reader.train(data_args, train_file_list), batch_size=batch_size) - feeder = fluid.DataFeeder(place=place, feed_list=network.feeds()) + train_reader = reader.train_batch_reader(config, train_file_list, batch_size=batch_size) def save_model(postfix): model_path = os.path.join(model_save_dir, postfix) @@ -105,20 +114,34 @@ def train(args, data_args, learning_rate, batch_size, pretrained_model, print 'save models to %s' % (model_path) fluid.io.save_persistables(exe, model_path) + def tensor(data, place, lod=None): + t = fluid.core.LoDTensor() + t.set(data, place) + if lod: + t.set_lod(lod) + return t + for pass_id in range(start_pass, num_passes): start_time = time.time() prev_start_time = start_time end_time = 0 - for batch_id, data in enumerate(train_reader()): + for batch_id in range(steps_per_pass): + im, face_box, head_box, labels, lod = next(train_reader) + im_t = tensor(im, place) + box1 = tensor(face_box, place, [lod]) + box2 = tensor(head_box, place, [lod]) + lbl_t = tensor(labels, place, [lod]) + feeding = {'image': im_t, 'face_box': box1, + 'head_box': box2, 'gt_label': lbl_t} + prev_start_time = start_time start_time = time.time() - if len(data) < 2 * devices_num: continue if args.parallel: fetch_vars = 
train_exe.run(fetch_list=[v.name for v in fetches], - feed=feeder.feed(data)) + feed=feeding) else: fetch_vars = exe.run(fluid.default_main_program(), - feed=feeder.feed(data), + feed=feeding, fetch_list=fetches) end_time = time.time() fetch_vars = [np.mean(np.array(v)) for v in fetch_vars] @@ -143,22 +166,13 @@ if __name__ == '__main__': data_dir = 'data/WIDERFACE/WIDER_train/images/' train_file_list = 'label/train_gt_widerface.res' - val_file_list = 'label/val_gt_widerface.res' - model_save_dir = args.model_save_dir - data_args = reader.Settings( - dataset=args.dataset, + config = reader.Settings( data_dir=data_dir, resize_h=args.resize_h, resize_w=args.resize_w, + apply_distort=True, apply_expand=False, - mean_value=[104., 117., 123], + mean_value=[104., 117., 123.], ap_version='11point') - train( - args, - data_args=data_args, - learning_rate=args.learning_rate, - batch_size=args.batch_size, - pretrained_model=args.pretrained_model, - num_passes=args.num_passes, - optimizer_method="momentum") + train(args, config, train_file_list, optimizer_method="momentum") diff --git a/fluid/icnet/README.md b/fluid/icnet/README.md new file mode 100644 index 0000000000000000000000000000000000000000..dc350ff5e66993b33b976018df36369b773a90c3 --- /dev/null +++ b/fluid/icnet/README.md @@ -0,0 +1,110 @@ +运行本目录下的程序示例需要使用PaddlePaddle develop最新版本。如果您的PaddlePaddle安装版本低于此要求,请按照[安装文档](http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/pip_install_cn.html)中的说明更新PaddlePaddle安装版本。 + + +## 代码结构 +``` +├── network.py # 网络结构定义脚本 +├── train.py # 训练任务脚本 +├── eval.py # 评估脚本 +├── infer.py # 预测脚本 +├── cityscape.py # 数据预处理脚本 +└── utils.py # 定义通用的函数 +``` + +## 简介 + +Image Cascade Network(ICNet)主要用于图像实时语义分割。相较于其它压缩计算的方法,ICNet即考虑了速度,也考虑了准确性。 +ICNet的主要思想是将输入图像变换为不同的分辨率,然后用不同计算复杂度的子网络计算不同分辨率的输入,然后将结果合并。ICNet由三个子网络组成,计算复杂度高的网络处理低分辨率输入,计算复杂度低的网络处理分辨率高的网络,通过这种方式在高分辨率图像的准确性和低复杂度网络的效率之间获得平衡。 + +整个网络结构如下: + +
+
+Figure 1
+
+Figure 2
+
+Figure 3
+
diff --git a/fluid/ocr_recognition/README.md b/fluid/ocr_recognition/README.md
--- a/fluid/ocr_recognition/README.md
+++ b/fluid/ocr_recognition/README.md
@@ -35,12 +34,12 @@
 In the training set, the label of each image is the sequence of indices of its characters in the dictionary. The label corresponding to Figure 1 is shown below:
```
-3835,8371,7191,2369,6876,4162,1938,168,1517,4590,3793
+80,84,68,82,83,72,78,77,68,67
```
-In the label above, `3835` is the index of the character '两', and `4590` is the index of the Chinese comma character.
+In the label above, `80` is the index of the character `Q`, and `67` is the index of the English character `D`.
-#### 1.1.2 Data preparation
+### Data preparation
 **A. Training set**
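A note on the label encoding shown in the hunk above, for illustration only: a label line is produced by mapping each character of the ground truth to its index in the dataset's dictionary. The sketch below uses a toy dictionary; the real index mapping is defined by the dataset itself.

```python
def encode_label(text, char_dict):
    """Map each character of `text` to its dictionary index,
    yielding a comma-separated label line such as '80,84,68,...'."""
    return ','.join(str(char_dict[ch]) for ch in text)

# Toy dictionary for demonstration; the dataset ships its own mapping.
toy_dict = {ch: idx for idx, ch in enumerate('ABCDEFGHIJKLMNOPQRSTUVWXYZ')}
print(encode_label('QUAD', toy_dict))  # -> 16,20,0,3
```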
@@ -105,7 +104,9 @@ data/test_images/00003.jpg
 The third mode: read an image path from stdin, then run inference once.
-#### 1.2 Training
+## Model training and prediction
+
+### Training
 Train on a single GPU with the default data:
@@ -121,7 +122,7 @@ env CUDA_VISIABLE_DEVICES=0,1,2,3 python ctc_train.py --parallel=True
 Run `python ctc_train.py --help` to see more usage modes and detailed argument descriptions.
-Figure 2 shows the convergence curve when training with the default arguments and the default dataset: the horizontal axis is the number of training iterations and the vertical axis is the sample-level error rate; the blue line is the sample error rate on the training set and the red line the error rate on the test set. Over 45 epochs of training, the lowest error rate on the test set was 21.11%, reached at epoch 60.
+Figure 2 shows the convergence curve when training with the default arguments and the default dataset: the horizontal axis is the number of training iterations and the vertical axis is the sample-level error rate; the blue line is the sample error rate on the training set and the red line the error rate on the test set. Over 60 epochs of training, the lowest error rate on the test set was 22.0%, reached at epoch 32.
@@ -130,7 +131,7 @@ env CUDA_VISIABLE_DEVICES=0,1,2,3 python ctc_train.py --parallel=True
-### 1.3 Evaluation
+## Testing
 Evaluate the model on a specified dataset by invoking the evaluation script with the following command:
@@ -144,7 +145,7 @@ env CUDA_VISIBLE_DEVICE=0 python eval.py \
 Run `python ctc_train.py --help` to see detailed argument descriptions.
-### 1.4 Prediction
+### Prediction
 Read an image path from standard input and run prediction on it:
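For orientation, a minimal sketch of the stdin-driven flow described above, assuming the recognizer consumes grayscale images at the 48-pixel height implied by `DATA_SHAPE = [1, 48, 512]` in `ctc_reader.py`; the real preprocessing lives in the reader and `infer.py`.

```python
import sys
import numpy as np
from PIL import Image

# Read one image path from stdin and build a 1 x 1 x 48 x W batch.
path = sys.stdin.readline().strip()
img = Image.open(path).convert('L')                # grayscale: channel dim is 1
w, h = img.size
img = img.resize((int(w * 48.0 / h), 48), Image.ANTIALIAS)  # fix height at 48
batch = np.array(img).astype('float32')[np.newaxis, np.newaxis, :, :]
print(batch.shape)                                 # (1, 1, 48, W)
```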
@@ -176,5 +177,3 @@ env CUDA_VISIBLE_DEVICE=0 python infer.py \
--model_path="models/model_00044_15000" \
--input_images_list="data/test.list"
```
-
->Note: for copyright reasons we have temporarily stopped offering the Chinese dataset for download and use; the data automatically downloaded through `ctc_reader.py` is an English dataset containing 300K images. Training results on the English dataset will be published later.
diff --git a/fluid/ocr_recognition/crnn_ctc_model.py b/fluid/ocr_recognition/crnn_ctc_model.py
index 1e687d2aa53c0c43a7b491a61b60fd2432210c95..79cf7b23954ce3331f46c50ee165dac720deae43 100644
--- a/fluid/ocr_recognition/crnn_ctc_model.py
+++ b/fluid/ocr_recognition/crnn_ctc_model.py
@@ -1,4 +1,7 @@
import paddle.fluid as fluid
+from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
+from paddle.fluid.initializer import init_on_cpu
+import math
def conv_bn_pool(input,
@@ -8,7 +11,8 @@ def conv_bn_pool(input,
param=None,
bias=None,
param_0=None,
- is_test=False):
+ is_test=False,
+ pooling=True):
tmp = input
for i in xrange(group):
tmp = fluid.layers.conv2d(
@@ -19,32 +23,25 @@ def conv_bn_pool(input,
param_attr=param if param_0 is None else param_0,
act=None, # LinearActivation
use_cudnn=True)
- #tmp = fluid.layers.Print(tmp)
tmp = fluid.layers.batch_norm(
input=tmp,
act=act,
param_attr=param,
bias_attr=bias,
is_test=is_test)
- tmp = fluid.layers.pool2d(
- input=tmp,
- pool_size=2,
- pool_type='max',
- pool_stride=2,
- use_cudnn=True,
- ceil_mode=True)
+ if pooling:
+ tmp = fluid.layers.pool2d(
+ input=tmp,
+ pool_size=2,
+ pool_type='max',
+ pool_stride=2,
+ use_cudnn=True,
+ ceil_mode=True)
return tmp
-def ocr_convs(input,
- num,
- with_bn,
- regularizer=None,
- gradient_clip=None,
- is_test=False):
- assert (num % 4 == 0)
-
+def ocr_convs(input, regularizer=None, gradient_clip=None, is_test=False):
b = fluid.ParamAttr(
regularizer=regularizer,
gradient_clip=gradient_clip,
@@ -63,7 +60,8 @@ def ocr_convs(input,
tmp = conv_bn_pool(tmp, 2, [32, 32], param=w1, bias=b, is_test=is_test)
tmp = conv_bn_pool(tmp, 2, [64, 64], param=w1, bias=b, is_test=is_test)
- tmp = conv_bn_pool(tmp, 2, [128, 128], param=w1, bias=b, is_test=is_test)
+ tmp = conv_bn_pool(
+ tmp, 2, [128, 128], param=w1, bias=b, is_test=is_test, pooling=False)
return tmp
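A quick sanity check on the `pooling=False` change above: the 2x2, stride-2, ceil-mode pooling now runs after the first two conv blocks only, so the final feature map keeps twice the height and width it had when all three blocks pooled. A back-of-the-envelope sketch, assuming the 48x512 input shape from `ctc_reader.py`:

```python
import math

def pooled(size, times):
    """Apply 2x2 stride-2 ceil-mode pooling `times` times to one dimension."""
    for _ in range(times):
        size = int(math.ceil(size / 2.0))
    return size

h, w = 48, 512
print(pooled(h, 3), pooled(w, 3))  # old: 3 poolings ->  6 x  64
print(pooled(h, 2), pooled(w, 2))  # new: 2 poolings -> 12 x 128
```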
@@ -75,8 +73,6 @@ def encoder_net(images,
is_test=False):
conv_features = ocr_convs(
images,
- 8,
- True,
regularizer=regularizer,
gradient_clip=gradient_clip,
is_test=is_test)
@@ -143,6 +139,7 @@ def ctc_train_net(images, label, args, num_classes):
L2_RATE = 0.0004
LR = 1.0e-3
MOMENTUM = 0.9
+ learning_rate_decay = None
regularizer = fluid.regularizer.L2Decay(L2_RATE)
fc_out = encoder_net(images, num_classes, regularizer=regularizer)
@@ -155,7 +152,15 @@ def ctc_train_net(images, label, args, num_classes):
error_evaluator = fluid.evaluator.EditDistance(
input=decoded_out, label=casted_label)
inference_program = fluid.default_main_program().clone(for_test=True)
- optimizer = fluid.optimizer.Momentum(learning_rate=LR, momentum=MOMENTUM)
+ if learning_rate_decay == "piecewise_decay":
+ learning_rate = fluid.layers.piecewise_decay([
+ args.total_step / 4, args.total_step / 2, args.total_step * 3 / 4
+ ], [LR, LR * 0.1, LR * 0.01, LR * 0.001])
+ else:
+ learning_rate = LR
+
+ optimizer = fluid.optimizer.Momentum(
+ learning_rate=learning_rate, momentum=MOMENTUM)
_, params_grads = optimizer.minimize(sum_cost)
model_average = None
if args.average_window > 0:
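For reference, the schedule that `piecewise_decay` builds from the arguments above places boundaries at 1/4, 1/2 and 3/4 of `total_step`, dropping the learning rate 10x at each boundary. A pure-Python sketch of the same step-to-rate mapping (illustration only; the real schedule is an op inside the Fluid program):

```python
def piecewise_lr(step, total_step, base_lr=1.0e-3):
    # Mirrors the boundaries/values wiring in ctc_train_net above.
    boundaries = [total_step / 4, total_step / 2, total_step * 3 / 4]
    values = [base_lr, base_lr * 0.1, base_lr * 0.01, base_lr * 0.001]
    for boundary, value in zip(boundaries, values):
        if step < boundary:
            return value
    return values[-1]

for s in (0, 200000, 400000, 600000):
    print(s, piecewise_lr(s, 720000))
# 0 -> 0.001, 200000 -> 0.0001, 400000 -> 1e-05, 600000 -> 1e-06 (approx.)
```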
diff --git a/fluid/ocr_recognition/ctc_reader.py b/fluid/ocr_recognition/ctc_reader.py
index ae8912b36933f6165babb8fb866bee5e074da850..db05dbeae73b67b12aebacdc84a04d5b180d2132 100644
--- a/fluid/ocr_recognition/ctc_reader.py
+++ b/fluid/ocr_recognition/ctc_reader.py
@@ -7,7 +7,7 @@ from os import path
from paddle.v2.image import load_image
import paddle.v2 as paddle
-NUM_CLASSES = 10784
+NUM_CLASSES = 95
DATA_SHAPE = [1, 48, 512]
DATA_MD5 = "7256b1d5420d8c3e74815196e58cdad5"
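`NUM_CLASSES = 95` is consistent with a character set of the 95 printable ASCII characters (codes 32 through 126); note this is an inference from the number, not something the reader states explicitly. A quick check:

```python
# Assumption: the English charset is printable ASCII, space through '~'.
charset = [chr(c) for c in range(32, 127)]
print(len(charset))      # 95, matching NUM_CLASSES
print(charset[0])        # ' '
print(charset[-1])       # '~'
```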
diff --git a/fluid/ocr_recognition/ctc_train.py b/fluid/ocr_recognition/ctc_train.py
index 9a1f5d9bad16be95715f8599fb38e4ea63aeeac8..dde07e51887ab6d7724f9b8893ae49479ee7b9a7 100644
--- a/fluid/ocr_recognition/ctc_train.py
+++ b/fluid/ocr_recognition/ctc_train.py
@@ -14,7 +14,7 @@ parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 32, "Minibatch size.")
-add_arg('pass_num', int, 100, "Number of training epochs.")
+add_arg('total_step', int, 720000, "Number of training iterations.")
add_arg('log_period', int, 1000, "Log period.")
add_arg('save_model_period', int, 15000, "Save model period. '-1' means never saving the model.")
add_arg('eval_period', int, 15000, "Evaluate period. '-1' means never evaluating the model.")
@@ -22,7 +22,7 @@ add_arg('save_model_dir', str, "./models", "The directory the model to be s
add_arg('init_model', str, None, "The init model file of directory.")
add_arg('use_gpu', bool, True, "Whether use GPU to train.")
add_arg('min_average_window',int, 10000, "Min average window.")
-add_arg('max_average_window',int, 15625, "Max average window. It is proposed to be set as the number of minibatch in a pass.")
+add_arg('max_average_window',int, 12500, "Max average window. It is proposed to be set as the number of minibatches in a pass.")
add_arg('average_window', float, 0.15, "Average window.")
add_arg('parallel', bool, False, "Whether use parallel training.")
# yapf: enable
@@ -90,54 +90,57 @@ def train(args, data_reader=ctc_reader):
results = [result[0] for result in results]
return results
- def test(pass_id, batch_id):
+ def test(iter_num):
error_evaluator.reset(exe)
for data in test_reader():
exe.run(inference_program, feed=get_feeder_data(data, place))
_, test_seq_error = error_evaluator.eval(exe)
- print "\nTime: %s; Pass[%d]-batch[%d]; Test seq error: %s.\n" % (
- time.time(), pass_id, batch_id, str(test_seq_error[0]))
+ print "\nTime: %s; Iter[%d]; Test seq error: %s.\n" % (
+ time.time(), iter_num, str(test_seq_error[0]))
- def save_model(args, exe, pass_id, batch_id):
- filename = "model_%05d_%d" % (pass_id, batch_id)
+ def save_model(args, exe, iter_num):
+ filename = "model_%05d" % iter_num
fluid.io.save_params(
exe, dirname=args.save_model_dir, filename=filename)
print "Saved model to: %s/%s." % (args.save_model_dir, filename)
- for pass_id in range(args.pass_num):
- batch_id = 1
+ iter_num = 0
+ while True:
total_loss = 0.0
total_seq_error = 0.0
# train a pass
for data in train_reader():
+ iter_num += 1
+ if iter_num > args.total_step:
+ return
results = train_one_batch(data)
total_loss += results[0]
total_seq_error += results[2]
# training log
- if batch_id % args.log_period == 0:
- print "\nTime: %s; Pass[%d]-batch[%d]; Avg Warp-CTC loss: %s; Avg seq err: %s" % (
- time.time(), pass_id, batch_id,
- total_loss / (batch_id * args.batch_size),
- total_seq_error / (batch_id * args.batch_size))
+ if iter_num % args.log_period == 0:
+ print "\nTime: %s; Iter[%d]; Avg Warp-CTC loss: %.3f; Avg seq err: %.3f" % (
+ time.time(), iter_num,
+ total_loss / (args.log_period * args.batch_size),
+ total_seq_error / (args.log_period * args.batch_size))
sys.stdout.flush()
+ total_loss = 0.0
+ total_seq_error = 0.0
# evaluate
- if batch_id % args.eval_period == 0:
+ if iter_num % args.eval_period == 0:
if model_average:
with model_average.apply(exe):
- test(pass_id, batch_id)
+ test(iter_num)
else:
- test(pass_id, batch_d)
+ test(iter_num)
# save model
- if batch_id % args.save_model_period == 0:
+ if iter_num % args.save_model_period == 0:
if model_average:
with model_average.apply(exe):
- save_model(args, exe, pass_id, batch_id)
+ save_model(args, exe, iter_num)
else:
- save_model(args, exe, pass_id, batch_id)
-
- batch_id += 1
+ save_model(args, exe, iter_num)
def main():
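To summarize the control-flow change in this file: the outer pass/epoch loop becomes a single step counter that owns the logging, evaluation and checkpointing periods and terminates purely on `total_step`. A condensed, runnable sketch of the pattern, with toy stand-ins for the reader and helpers:

```python
def train_batches():          # toy stand-in for train_reader()
    return iter(range(4))

def train_one_batch(data):    # toy stand-in for the real training step
    pass

def run(total_step, log_period=2, eval_period=3, save_period=5):
    iter_num = 0
    while True:               # epochs no longer bound the loop
        for data in train_batches():
            iter_num += 1
            if iter_num > total_step:
                return        # stop on the step budget alone
            train_one_batch(data)
            if iter_num % log_period == 0:
                print('log at iter %d' % iter_num)
            if iter_num % eval_period == 0:
                print('eval at iter %d' % iter_num)
            if iter_num % save_period == 0:
                print('save at iter %d' % iter_num)

run(total_step=10)
```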
diff --git a/fluid/ocr_recognition/eval.py b/fluid/ocr_recognition/eval.py
index be0a04380b62b274abfa954cbeed451afb441922..6924131686a1387a55cdf85136da39a249a369a7 100644
--- a/fluid/ocr_recognition/eval.py
+++ b/fluid/ocr_recognition/eval.py
@@ -35,7 +35,7 @@ def evaluate(args, eval=ctc_eval, data_reader=ctc_reader):
# prepare environment
place = fluid.CPUPlace()
- if use_gpu:
+ if args.use_gpu:
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
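The fix above simply corrects a stale `use_gpu` reference. Elsewhere in this PR the same device choice is written as a one-liner, which avoids the mutate-after-default pattern; either form is equivalent:

```python
import paddle.fluid as fluid

def make_executor(use_gpu):
    # Same selection as eval.py, collapsed into one expression.
    place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
    return fluid.Executor(place)
```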
diff --git a/fluid/ocr_recognition/images/demo.jpg b/fluid/ocr_recognition/images/demo.jpg
index be5aee506f68861583903d04c526523afc299ab8..d2f8a24afbe862c51f913043df8d8c8be49b521b 100644
Binary files a/fluid/ocr_recognition/images/demo.jpg and b/fluid/ocr_recognition/images/demo.jpg differ
diff --git a/fluid/ocr_recognition/images/train.jpg b/fluid/ocr_recognition/images/train.jpg
index ec86fb1bf828699b3b63926accad0e943f25feeb..71301ef59615ee8983ac19fef836b7f338818324 100644
Binary files a/fluid/ocr_recognition/images/train.jpg and b/fluid/ocr_recognition/images/train.jpg differ