未验证 提交 4eaaf5bc 编写于 作者: G Guanghua Yu 提交者: GitHub

support visualdl in preprocess data (#1828)

* support visualdl in preprocess data

* support most preprocess op
上级 16735ddb
...@@ -33,9 +33,10 @@ import random ...@@ -33,9 +33,10 @@ import random
import math import math
import numpy as np import numpy as np
import os import os
import six
import cv2 import cv2
from PIL import Image, ImageEnhance, ImageDraw from PIL import Image, ImageEnhance, ImageDraw, ImageOps
from ppdet.core.workspace import serializable from ppdet.core.workspace import serializable
from ppdet.modeling.ops import AnchorGrid from ppdet.modeling.ops import AnchorGrid
...@@ -2535,22 +2536,65 @@ class DebugVisibleImage(BaseOperator): ...@@ -2535,22 +2536,65 @@ class DebugVisibleImage(BaseOperator):
(Currently only supported when not cropping and flipping image.) (Currently only supported when not cropping and flipping image.)
""" """
def __init__(self, output_dir='output/debug', is_normalized=False): def __init__(self,
output_dir='output/debug',
use_vdl=False,
is_normalized=False):
super(DebugVisibleImage, self).__init__() super(DebugVisibleImage, self).__init__()
self.is_normalized = is_normalized self.is_normalized = is_normalized
self.output_dir = output_dir self.output_dir = output_dir
self.use_vdl = use_vdl
if not os.path.isdir(output_dir): if not os.path.isdir(output_dir):
os.makedirs(output_dir) os.makedirs(output_dir)
if not isinstance(self.is_normalized, bool): if not isinstance(self.is_normalized, bool):
raise TypeError("{}: input type is invalid.".format(self)) raise TypeError("{}: input type is invalid.".format(self))
if self.use_vdl:
assert six.PY3, "VisualDL requires Python >= 3.5"
from visualdl import LogWriter
self.vdl_writer = LogWriter(self.output_dir)
def __call__(self, sample, context=None): def __call__(self, sample, context=None):
image = Image.open(sample['im_file']).convert('RGB')
out_file_name = sample['im_file'].split('/')[-1] out_file_name = sample['im_file'].split('/')[-1]
if self.use_vdl:
origin_image = Image.open(sample['im_file']).convert('RGB')
origin_image = ImageOps.exif_transpose(origin_image)
image_np = np.array(origin_image)
self.vdl_writer.add_image("original/{}".format(out_file_name),
image_np, 0)
if not isinstance(sample['image'], np.ndarray):
raise TypeError("{}: sample[image] type is not numpy.".format(self))
image = Image.fromarray(np.uint8(sample['image']))
width = sample['w'] width = sample['w']
height = sample['h'] height = sample['h']
gt_bbox = sample['gt_bbox'] gt_bbox = sample['gt_bbox']
gt_class = sample['gt_class'] gt_class = sample['gt_class']
if 'gt_poly' in sample.keys():
poly_to_mask = Poly2Mask()
sample = poly_to_mask(sample)
if 'gt_segm' in sample.keys():
import pycocotools.mask as mask_util
from ppdet.utils.colormap import colormap
image_np = np.array(image).astype('float32')
mask_color_id = 0
w_ratio = .4
alpha = 0.7
color_list = colormap(rgb=True)
gt_segm = sample['gt_segm']
for mask in gt_segm:
color_mask = color_list[mask_color_id % len(color_list), 0:3]
mask_color_id += 1
for c in range(3):
color_mask[c] = color_mask[c] * (1 - w_ratio
) + w_ratio * 255
idx = np.nonzero(mask)
image_np[idx[0], idx[1], :] *= 1.0 - alpha
image_np[idx[0], idx[1], :] += alpha * color_mask
image = Image.fromarray(np.uint8(image_np))
draw = ImageDraw.Draw(image) draw = ImageDraw.Draw(image)
for i in range(gt_bbox.shape[0]): for i in range(gt_bbox.shape[0]):
if self.is_normalized: if self.is_normalized:
...@@ -2566,7 +2610,7 @@ class DebugVisibleImage(BaseOperator): ...@@ -2566,7 +2610,7 @@ class DebugVisibleImage(BaseOperator):
width=2, width=2,
fill='green') fill='green')
# draw label # draw label
text = str(gt_class[i][0]) text = 'id' + str(gt_class[i][0])
tw, th = draw.textsize(text) tw, th = draw.textsize(text)
draw.rectangle( draw.rectangle(
[(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill='green') [(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill='green')
...@@ -2583,12 +2627,17 @@ class DebugVisibleImage(BaseOperator): ...@@ -2583,12 +2627,17 @@ class DebugVisibleImage(BaseOperator):
for i in range(gt_keypoint.shape[0]): for i in range(gt_keypoint.shape[0]):
keypoint = gt_keypoint[i] keypoint = gt_keypoint[i]
for j in range(int(keypoint.shape[0] / 2)): for j in range(int(keypoint.shape[0] / 2)):
x1 = round(keypoint[2 * j]).astype(np.int32) x1 = round(keypoint[2 * j])
y1 = round(keypoint[2 * j + 1]).astype(np.int32) y1 = round(keypoint[2 * j + 1])
draw.ellipse( draw.ellipse(
(x1, y1, x1 + 5, y1 + 5), fill='green', outline='green') (x1, y1, x1 + 5, y1 + 5), fill='green', outline='green')
save_path = os.path.join(self.output_dir, out_file_name) save_path = os.path.join(self.output_dir, out_file_name)
image.save(save_path, quality=95) if self.use_vdl:
preprocess_image_np = np.array(image)
self.vdl_writer.add_image("preprocess/{}".format(out_file_name),
preprocess_image_np, 0)
else:
image.save(save_path, quality=95)
return sample return sample
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册