From 4cbf8369c48d3f33c3a5179d8210049447efb98f Mon Sep 17 00:00:00 2001 From: sunyanfang01 Date: Tue, 16 Jun 2020 15:45:12 +0800 Subject: [PATCH] add transforms vdl --- paddlex/cv/transforms/cls_transforms.py | 35 +++++++++++++++- paddlex/cv/transforms/det_transforms.py | 41 +++++++++++++++++-- paddlex/cv/transforms/seg_transforms.py | 38 +++++++++++++++-- tutorials/train/classification/mobilenetv2.py | 13 ++++++ tutorials/train/detection/yolov3_darknet53.py | 24 +++++++++++ tutorials/train/segmentation/deeplabv3p.py | 12 ++++++ 6 files changed, 153 insertions(+), 10 deletions(-) diff --git a/paddlex/cv/transforms/cls_transforms.py b/paddlex/cv/transforms/cls_transforms.py index dbcd342..3a504af 100644 --- a/paddlex/cv/transforms/cls_transforms.py +++ b/paddlex/cv/transforms/cls_transforms.py @@ -15,6 +15,7 @@ from .ops import * from .imgaug_support import execute_imgaug import random +import os import os.path as osp import numpy as np from PIL import Image, ImageEnhance @@ -57,8 +58,24 @@ class Compose(ClsTransform): raise Exception( "Elements in transforms should be defined in 'paddlex.cls.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/" ) - - def __call__(self, im, label=None): + self.images_writer = None + + def set_vdl(self, vdl_save_dir=None): + # 对数据预处理结果在VisualDL中可视化 + self.images_writer = None + if vdl_save_dir is not None: + if not osp.isdir(vdl_save_dir): + if osp.exists(vdl_save_dir): + os.remove(vdl_save_dir) + os.makedirs(vdl_save_dir) + from visualdl import LogWriter + vdl_images_dir = osp.join(vdl_save_dir, 'image_transforms') + self.images_writer = LogWriter(vdl_images_dir) + + def release_vdl(self): + self.images_writer = None + + def __call__(self, im, label=None, step=0): """ Args: im (str/np.ndarray): 图像路径/图像np.ndarray数据。 @@ -67,6 +84,7 @@ class Compose(ClsTransform): tuple: 根据网络所需字段所组成的tuple; 字段由transforms中的最后一个数据预处理操作决定。 """ + im_file = str(step) if isinstance(im, np.ndarray): if len(im.shape) != 3: raise Exception( @@ -74,10 +92,16 @@ class Compose(ClsTransform): format(len(im.shape))) else: try: + im_file = im im = cv2.imread(im).astype('float32') except: raise TypeError('Can\'t read The image file {}!'.format(im)) im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + if self.images_writer is not None: + self.images_writer.add_image(tag='0. origin image', + img=im, + step=step) + op_id = 1 for op in self.transforms: if isinstance(op, ClsTransform): outputs = op(im, label) @@ -91,6 +115,12 @@ class Compose(ClsTransform): outputs = (im, ) if label is not None: outputs = (im, label) + if self.images_writer is not None: + tag = str(op_id) + '. ' + op.__class__.__name__ + self.images_writer.add_image(tag=tag, + img=im, + step=step) + op_id += 1 return outputs def add_augmenters(self, augmenters): @@ -434,6 +464,7 @@ class RandomDistort(ClsTransform): params['im'] = im if np.random.uniform(0, 1) < prob: im = ops[id](**params) + im = im.astype('float32') if label is None: return (im, ) else: diff --git a/paddlex/cv/transforms/det_transforms.py b/paddlex/cv/transforms/det_transforms.py index 45eff25..4a1683f 100644 --- a/paddlex/cv/transforms/det_transforms.py +++ b/paddlex/cv/transforms/det_transforms.py @@ -18,6 +18,7 @@ except Exception: from collections import Sequence import random +import os import os.path as osp import numpy as np @@ -50,7 +51,7 @@ class Compose(DetTransform): ValueError: 数据长度不匹配。 """ - def __init__(self, transforms): + def __init__(self, transforms, vdl_save_dir=None): if not isinstance(transforms, list): raise TypeError('The transforms must be a list!') if len(transforms) < 1: @@ -69,8 +70,24 @@ class Compose(DetTransform): raise Exception( "Elements in transforms should be defined in 'paddlex.det.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/" ) - - def __call__(self, im, im_info=None, label_info=None): + self.images_writer = None + + def set_vdl(self, vdl_save_dir=None): + # 对数据预处理结果在VisualDL中可视化 + self.images_writer = None + if vdl_save_dir is not None: + if not osp.isdir(vdl_save_dir): + if osp.exists(vdl_save_dir): + os.remove(vdl_save_dir) + os.makedirs(vdl_save_dir) + from visualdl import LogWriter + vdl_images_dir = osp.join(vdl_save_dir, 'image_transforms') + self.images_writer = LogWriter(vdl_images_dir) + + def release_vdl(self): + self.images_writer = None + + def __call__(self, im, im_info=None, label_info=None, step=0): """ Args: im (str/np.ndarray): 图像路径/图像np.ndarray数据。 @@ -133,12 +150,21 @@ class Compose(DetTransform): return (im, im_info) else: return (im, im_info, label_info) - + + if isinstance(im, str): + im_file = im + else: + im_file = str(step) outputs = decode_image(im, im_info, label_info) im = outputs[0] im_info = outputs[1] if len(outputs) == 3: label_info = outputs[2] + if self.images_writer is not None: + self.images_writer.add_image(tag='0. origin image', + img=im, + step=step) + op_id = 1 for op in self.transforms: if im is None: return None @@ -151,6 +177,12 @@ class Compose(DetTransform): outputs = (im, im_info, label_info) else: outputs = (im, im_info) + if self.images_writer is not None: + tag = str(op_id) + '. ' + op.__class__.__name__ + self.images_writer.add_image(tag=tag, + img=im, + step=step) + op_id += 1 return outputs def add_augmenters(self, augmenters): @@ -621,6 +653,7 @@ class RandomDistort(DetTransform): if np.random.uniform(0, 1) < prob: im = ops[id](**params) + im = im.astype('float32') if label_info is None: return (im, im_info) else: diff --git a/paddlex/cv/transforms/seg_transforms.py b/paddlex/cv/transforms/seg_transforms.py index 9ea1c3b..8d51343 100644 --- a/paddlex/cv/transforms/seg_transforms.py +++ b/paddlex/cv/transforms/seg_transforms.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os from .ops import * from .imgaug_support import execute_imgaug import random @@ -45,7 +46,7 @@ class Compose(SegTransform): """ - def __init__(self, transforms): + def __init__(self, transforms, vdl_save_dir=None): if not isinstance(transforms, list): raise TypeError('The transforms must be a list!') if len(transforms) < 1: @@ -61,8 +62,24 @@ class Compose(SegTransform): raise Exception( "Elements in transforms should be defined in 'paddlex.seg.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/" ) - - def __call__(self, im, im_info=None, label=None): + self.images_writer = None + + def set_vdl(self, vdl_save_dir=None): + # 对数据预处理结果在VisualDL中可视化 + self.images_writer = None + if vdl_save_dir is not None: + if not osp.isdir(vdl_save_dir): + if osp.exists(vdl_save_dir): + os.remove(vdl_save_dir) + os.makedirs(vdl_save_dir) + from visualdl import LogWriter + vdl_images_dir = osp.join(vdl_save_dir, 'image_transforms') + self.images_writer = LogWriter(vdl_images_dir) + + def release_vdl(self): + self.images_writer = None + + def __call__(self, im, im_info=None, label=None, step=0): """ Args: im (str/np.ndarray): 图像路径/图像np.ndarray数据。 @@ -75,7 +92,7 @@ class Compose(SegTransform): Returns: tuple: 根据网络所需字段所组成的tuple;字段由transforms中的最后一个数据预处理操作决定。 """ - + im_file = str(step) if im_info is None: im_info = list() if isinstance(im, np.ndarray): @@ -85,6 +102,7 @@ class Compose(SegTransform): format(len(im.shape))) else: try: + im_file = im im = cv2.imread(im).astype('float32') except: raise ValueError('Can\'t read The image file {}!'.format(im)) @@ -93,6 +111,11 @@ class Compose(SegTransform): if label is not None: if not isinstance(label, np.ndarray): label = np.asarray(Image.open(label)) + if self.images_writer is not None: + self.images_writer.add_image(tag='0. origin image', + img=im, + step=step) + op_id = 1 for op in self.transforms: if isinstance(op, SegTransform): outputs = op(im, im_info, label) @@ -107,6 +130,12 @@ class Compose(SegTransform): outputs = (im, im_info, label) else: outputs = (im, im_info) + if self.images_writer is not None: + tag = str(op_id) + '. ' + op.__class__.__name__ + self.images_writer.add_image(tag=tag, + img=im, + step=step) + op_id += 1 return outputs def add_augmenters(self, augmenters): @@ -1053,6 +1082,7 @@ class RandomDistort(SegTransform): params['im'] = im if np.random.uniform(0, 1) < prob: im = ops[id](**params) + im = im.astype('float32') if label is None: return (im, im_info) else: diff --git a/tutorials/train/classification/mobilenetv2.py b/tutorials/train/classification/mobilenetv2.py index 3f63712..7412eb1 100644 --- a/tutorials/train/classification/mobilenetv2.py +++ b/tutorials/train/classification/mobilenetv2.py @@ -34,6 +34,19 @@ eval_dataset = pdx.datasets.ImageNet( label_list='vegetables_cls/labels.txt', transforms=eval_transforms) +# 可使用VisualDL查看数据预处理的中间结果 +# VisualDL启动方式: visualdl --logdir vdl_output --port 8001 +# 浏览器打开 https://0.0.0.0:8001即可 +# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP +train_transforms.set_vdl(vdl_save_dir='vdl_output') +for step, data in enumerate(train_dataset.iterator()): + data.append(step) + train_transforms(*data) + if step == 5: + break +train_transforms.release_vdl() + + # 初始化模型,并进行训练 # 可使用VisualDL查看训练指标 # VisualDL启动方式: visualdl --logdir output/mobilenetv2/vdl_log --port 8001 diff --git a/tutorials/train/detection/yolov3_darknet53.py b/tutorials/train/detection/yolov3_darknet53.py index c38656b..5511b82 100644 --- a/tutorials/train/detection/yolov3_darknet53.py +++ b/tutorials/train/detection/yolov3_darknet53.py @@ -38,6 +38,30 @@ eval_dataset = pdx.datasets.VOCDetection( label_list='insect_det/labels.txt', transforms=eval_transforms) +# 可使用VisualDL查看数据预处理的中间结果 +# VisualDL启动方式: visualdl --logdir vdl_output --port 8001 +# 浏览器打开 https://0.0.0.0:8001即可 +# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP +train_transforms.set_vdl(vdl_save_dir='vdl_output') +for step, data in enumerate(train_dataset.iterator()): + data.append(step) + train_transforms(*data) + if step == 5: + break +train_transforms.release_vdl() + +# 可使用VisualDL查看数据预处理的中间结果 +# VisualDL启动方式: visualdl --logdir vdl_output --port 8001 +# 浏览器打开 https://0.0.0.0:8001即可 +# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP +train_transforms.vdl_save_dir = 'vdl_output' +for step, data in enumerate(train_dataset.iterator()): + data.append(step) + train_transforms(*data) + if step == 5: + break +train_transforms.vdl_save_dir = None + # 初始化模型,并进行训练 # 可使用VisualDL查看训练指标 # VisualDL启动方式: visualdl --logdir output/yolov3_darknet/vdl_log --port 8001 diff --git a/tutorials/train/segmentation/deeplabv3p.py b/tutorials/train/segmentation/deeplabv3p.py index 346a229..701d2a1 100644 --- a/tutorials/train/segmentation/deeplabv3p.py +++ b/tutorials/train/segmentation/deeplabv3p.py @@ -33,6 +33,18 @@ eval_dataset = pdx.datasets.SegDataset( label_list='optic_disc_seg/labels.txt', transforms=eval_transforms) +# 可使用VisualDL查看数据预处理的中间结果 +# VisualDL启动方式: visualdl --logdir vdl_output --port 8001 +# 浏览器打开 https://0.0.0.0:8001即可 +# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP +train_transforms.vdl_save_dir = 'vdl_output' +for step, data in enumerate(train_dataset.iterator()): + data.append(step) + train_transforms(*data) + if step == 5: + break +train_transforms.vdl_save_dir = None + # 初始化模型,并进行训练 # 可使用VisualDL查看训练指标 # VisualDL启动方式: visualdl --logdir output/deeplab/vdl_log --port 8001 -- GitLab