From 26b3788b4b7e2b66027951f98975731dc9914dad Mon Sep 17 00:00:00 2001 From: Yuan Gao Date: Thu, 12 Apr 2018 12:23:02 +0800 Subject: [PATCH] Fix data aug (#829) * fix expand data aug * fix expand data aug --- fluid/object_detection/image_util.py | 11 +++++------ fluid/object_detection/reader.py | 4 +--- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/fluid/object_detection/image_util.py b/fluid/object_detection/image_util.py index e538449a..b8464cfe 100644 --- a/fluid/object_detection/image_util.py +++ b/fluid/object_detection/image_util.py @@ -85,8 +85,7 @@ def satisfy_sample_constraint(sampler, sample_bbox, bbox_labels): return False -def generate_batch_samples(batch_sampler, bbox_labels, image_width, - image_height): +def generate_batch_samples(batch_sampler, bbox_labels): sampled_bbox = [] index = [] c = 0 @@ -217,8 +216,8 @@ def distort_image(img, settings): def expand_image(img, bbox_labels, img_width, img_height, settings): prob = random.uniform(0, 1) if prob < settings._expand_prob: - expand_ratio = random.uniform(1, settings._expand_max_ratio) - if expand_ratio - 1 >= 0.01: + if _expand_max_ratio - 1 >= 0.01: + expand_ratio = random.uniform(1, settings._expand_max_ratio) height = int(img_height * expand_ratio) width = int(img_width * expand_ratio) h_off = math.floor(random.uniform(0, height - img_height)) @@ -231,5 +230,5 @@ def expand_image(img, bbox_labels, img_width, img_height, settings): expand_img = Image.fromarray(expand_img) expand_img.paste(img, (int(w_off), int(h_off))) bbox_labels = transform_labels(bbox_labels, expand_bbox) - return expand_img, bbox_labels - return img, bbox_labels + return expand_img, bbox_labels, width, height + return img, bbox_labels, img_width, img_height diff --git a/fluid/object_detection/reader.py b/fluid/object_detection/reader.py index 47f78f63..43c54b4c 100644 --- a/fluid/object_detection/reader.py +++ b/fluid/object_detection/reader.py @@ -193,7 +193,7 @@ def _reader_creator(settings, file_list, mode, shuffle): if settings._apply_distort: img = image_util.distort_image(img, settings) if settings._apply_expand: - img, bbox_labels = image_util.expand_image( + img, bbox_labels, img_width, img_height = image_util.expand_image( img, bbox_labels, img_width, img_height, settings) batch_sampler = [] # hard-code here @@ -236,7 +236,6 @@ def _reader_creator(settings, file_list, mode, shuffle): sample_labels[i][1] = 1 - sample_labels[i][3] sample_labels[i][3] = 1 - tmp - #draw_bounding_box_on_image(img, sample_labels, image_name, category_names, normalized=True) # HWC to CHW if len(img.shape) == 3: img = np.swapaxes(img, 1, 2) @@ -287,7 +286,6 @@ def draw_bounding_box_on_image(image, (left, top)], width=thickness, fill=color) - #draw.rectangle([xmin, ymin, xmax, ymax], outline=color) if with_text: if image.mode == 'RGB': draw.text((left, top), category_name, (255, 255, 0)) -- GitLab