diff --git a/paddlex/__init__.py b/paddlex/__init__.py index b80363f2e6adfdbd6ce712cfec486540753abbb7..c4b47a690771dc4bcc2d3d7ff84c2b359d27cd48 100644 --- a/paddlex/__init__.py +++ b/paddlex/__init__.py @@ -48,6 +48,7 @@ if hub.version.hub_version < '1.6.2': env_info = get_environ_info() load_model = cv.models.load_model datasets = cv.datasets +transforms = cv.transforms log_level = 2 diff --git a/paddlex/cv/transforms/__init__.py b/paddlex/cv/transforms/__init__.py index 37c14e75f72f8c6b76a608116419d58437fab99e..c74b5b19e8d1e007674f6d17a30736f42dde1789 100644 --- a/paddlex/cv/transforms/__init__.py +++ b/paddlex/cv/transforms/__init__.py @@ -15,3 +15,5 @@ from . import cls_transforms from . import det_transforms from . import seg_transforms +from . import visualize +visualize = visualize.visualize diff --git a/paddlex/cv/transforms/cls_transforms.py b/paddlex/cv/transforms/cls_transforms.py index 3a504af0395eb144f46afd0a3e6b11d74434c031..29f62252d899a1a235e74047268bbc3076e423dd 100644 --- a/paddlex/cv/transforms/cls_transforms.py +++ b/paddlex/cv/transforms/cls_transforms.py @@ -58,24 +58,8 @@ class Compose(ClsTransform): raise Exception( "Elements in transforms should be defined in 'paddlex.cls.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/" ) - self.images_writer = None - - def set_vdl(self, vdl_save_dir=None): - # 对数据预处理结果在VisualDL中可视化 - self.images_writer = None - if vdl_save_dir is not None: - if not osp.isdir(vdl_save_dir): - if osp.exists(vdl_save_dir): - os.remove(vdl_save_dir) - os.makedirs(vdl_save_dir) - from visualdl import LogWriter - vdl_images_dir = osp.join(vdl_save_dir, 'image_transforms') - self.images_writer = LogWriter(vdl_images_dir) - - def release_vdl(self): - self.images_writer = None - - def __call__(self, im, label=None, step=0): + + def __call__(self, im, label=None, images_writer=None, step=0): """ Args: im (str/np.ndarray): 图像路径/图像np.ndarray数据。 @@ -84,7 +68,6 @@ class Compose(ClsTransform): tuple: 根据网络所需字段所组成的tuple; 字段由transforms中的最后一个数据预处理操作决定。 """ - im_file = str(step) if isinstance(im, np.ndarray): if len(im.shape) != 3: raise Exception( @@ -92,15 +75,14 @@ class Compose(ClsTransform): format(len(im.shape))) else: try: - im_file = im im = cv2.imread(im).astype('float32') except: raise TypeError('Can\'t read The image file {}!'.format(im)) im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) - if self.images_writer is not None: - self.images_writer.add_image(tag='0. origin image', - img=im, - step=step) + if images_writer is not None: + images_writer.add_image(tag='0. origin image', + img=im, + step=step) op_id = 1 for op in self.transforms: if isinstance(op, ClsTransform): @@ -115,9 +97,9 @@ class Compose(ClsTransform): outputs = (im, ) if label is not None: outputs = (im, label) - if self.images_writer is not None: + if images_writer is not None: tag = str(op_id) + '. ' + op.__class__.__name__ - self.images_writer.add_image(tag=tag, + images_writer.add_image(tag=tag, img=im, step=step) op_id += 1 diff --git a/paddlex/cv/transforms/det_transforms.py b/paddlex/cv/transforms/det_transforms.py index 4a1683fc8a9eb2177e2976f905a2b86d3d524a5f..a4c50924ce74459859eaeb661c6fa4ebf26ddcb6 100644 --- a/paddlex/cv/transforms/det_transforms.py +++ b/paddlex/cv/transforms/det_transforms.py @@ -51,7 +51,7 @@ class Compose(DetTransform): ValueError: 数据长度不匹配。 """ - def __init__(self, transforms, vdl_save_dir=None): + def __init__(self, transforms): if not isinstance(transforms, list): raise TypeError('The transforms must be a list!') if len(transforms) < 1: @@ -70,24 +70,8 @@ class Compose(DetTransform): raise Exception( "Elements in transforms should be defined in 'paddlex.det.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/" ) - self.images_writer = None - - def set_vdl(self, vdl_save_dir=None): - # 对数据预处理结果在VisualDL中可视化 - self.images_writer = None - if vdl_save_dir is not None: - if not osp.isdir(vdl_save_dir): - if osp.exists(vdl_save_dir): - os.remove(vdl_save_dir) - os.makedirs(vdl_save_dir) - from visualdl import LogWriter - vdl_images_dir = osp.join(vdl_save_dir, 'image_transforms') - self.images_writer = LogWriter(vdl_images_dir) - - def release_vdl(self): - self.images_writer = None - def __call__(self, im, im_info=None, label_info=None, step=0): + def __call__(self, im, im_info=None, label_info=None, images_writer=None, step=0): """ Args: im (str/np.ndarray): 图像路径/图像np.ndarray数据。 @@ -151,19 +135,15 @@ class Compose(DetTransform): else: return (im, im_info, label_info) - if isinstance(im, str): - im_file = im - else: - im_file = str(step) outputs = decode_image(im, im_info, label_info) im = outputs[0] im_info = outputs[1] if len(outputs) == 3: label_info = outputs[2] - if self.images_writer is not None: - self.images_writer.add_image(tag='0. origin image', - img=im, - step=step) + if images_writer is not None: + images_writer.add_image(tag='0. origin image', + img=im, + step=step) op_id = 1 for op in self.transforms: if im is None: @@ -177,9 +157,9 @@ class Compose(DetTransform): outputs = (im, im_info, label_info) else: outputs = (im, im_info) - if self.images_writer is not None: + if images_writer is not None: tag = str(op_id) + '. ' + op.__class__.__name__ - self.images_writer.add_image(tag=tag, + images_writer.add_image(tag=tag, img=im, step=step) op_id += 1 diff --git a/paddlex/cv/transforms/seg_transforms.py b/paddlex/cv/transforms/seg_transforms.py index 8d51343b626d583d21f3c82c66f1f792216ce9af..7a21281a17e7028e1311a0fb6f02c9c741ff1460 100644 --- a/paddlex/cv/transforms/seg_transforms.py +++ b/paddlex/cv/transforms/seg_transforms.py @@ -46,7 +46,7 @@ class Compose(SegTransform): """ - def __init__(self, transforms, vdl_save_dir=None): + def __init__(self, transforms): if not isinstance(transforms, list): raise TypeError('The transforms must be a list!') if len(transforms) < 1: @@ -62,24 +62,8 @@ class Compose(SegTransform): raise Exception( "Elements in transforms should be defined in 'paddlex.seg.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/" ) - self.images_writer = None - - def set_vdl(self, vdl_save_dir=None): - # 对数据预处理结果在VisualDL中可视化 - self.images_writer = None - if vdl_save_dir is not None: - if not osp.isdir(vdl_save_dir): - if osp.exists(vdl_save_dir): - os.remove(vdl_save_dir) - os.makedirs(vdl_save_dir) - from visualdl import LogWriter - vdl_images_dir = osp.join(vdl_save_dir, 'image_transforms') - self.images_writer = LogWriter(vdl_images_dir) - - def release_vdl(self): - self.images_writer = None - - def __call__(self, im, im_info=None, label=None, step=0): + + def __call__(self, im, im_info=None, label=None, images_writer=None, step=0): """ Args: im (str/np.ndarray): 图像路径/图像np.ndarray数据。 @@ -92,7 +76,6 @@ class Compose(SegTransform): Returns: tuple: 根据网络所需字段所组成的tuple;字段由transforms中的最后一个数据预处理操作决定。 """ - im_file = str(step) if im_info is None: im_info = list() if isinstance(im, np.ndarray): @@ -102,7 +85,6 @@ class Compose(SegTransform): format(len(im.shape))) else: try: - im_file = im im = cv2.imread(im).astype('float32') except: raise ValueError('Can\'t read The image file {}!'.format(im)) @@ -111,10 +93,10 @@ class Compose(SegTransform): if label is not None: if not isinstance(label, np.ndarray): label = np.asarray(Image.open(label)) - if self.images_writer is not None: - self.images_writer.add_image(tag='0. origin image', - img=im, - step=step) + if images_writer is not None: + images_writer.add_image(tag='0. origin image', + img=im, + step=step) op_id = 1 for op in self.transforms: if isinstance(op, SegTransform): @@ -130,9 +112,9 @@ class Compose(SegTransform): outputs = (im, im_info, label) else: outputs = (im, im_info) - if self.images_writer is not None: + if images_writer is not None: tag = str(op_id) + '. ' + op.__class__.__name__ - self.images_writer.add_image(tag=tag, + images_writer.add_image(tag=tag, img=im, step=step) op_id += 1 diff --git a/paddlex/cv/transforms/visualize.py b/paddlex/cv/transforms/visualize.py new file mode 100644 index 0000000000000000000000000000000000000000..02df46ca7e9e67f33908f6a14e2be57675de8f27 --- /dev/null +++ b/paddlex/cv/transforms/visualize.py @@ -0,0 +1,39 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import os.path as osp +from .cls_transforms import ClsTransform +from .det_transforms import DetTransform +from .seg_transforms import SegTransform + +def visualize(dataset, index=0, steps=3, save_dir='vdl_output'): + transforms = dataset.transforms + if not osp.isdir(save_dir): + if osp.exists(save_dir): + os.remove(save_dir) + os.makedirs(save_dir) + for i, data in enumerate(dataset.iterator()): + if i == index: + break + from visualdl import LogWriter + vdl_save_dir = osp.join(save_dir, 'image_transforms') + images_writer = LogWriter(vdl_save_dir) + data.append(images_writer) + for s in range(steps): + if s != 0: + data.pop() + data.append(s) + transforms(*data) + \ No newline at end of file diff --git a/tutorials/train/classification/mobilenetv2.py b/tutorials/train/classification/mobilenetv2.py index 7412eb1d1c142588b0489af3f209d51fb3a1c01e..3f637125b760de6d992d6a062e4d456bf5038426 100644 --- a/tutorials/train/classification/mobilenetv2.py +++ b/tutorials/train/classification/mobilenetv2.py @@ -34,19 +34,6 @@ eval_dataset = pdx.datasets.ImageNet( label_list='vegetables_cls/labels.txt', transforms=eval_transforms) -# 可使用VisualDL查看数据预处理的中间结果 -# VisualDL启动方式: visualdl --logdir vdl_output --port 8001 -# 浏览器打开 https://0.0.0.0:8001即可 -# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP -train_transforms.set_vdl(vdl_save_dir='vdl_output') -for step, data in enumerate(train_dataset.iterator()): - data.append(step) - train_transforms(*data) - if step == 5: - break -train_transforms.release_vdl() - - # 初始化模型,并进行训练 # 可使用VisualDL查看训练指标 # VisualDL启动方式: visualdl --logdir output/mobilenetv2/vdl_log --port 8001 diff --git a/tutorials/train/detection/yolov3_darknet53.py b/tutorials/train/detection/yolov3_darknet53.py index 2f09f45de924e0bb0aba3b4be659d9475775b0e2..c38656b04e9a35cd033dc583811c58aa8baafba2 100644 --- a/tutorials/train/detection/yolov3_darknet53.py +++ b/tutorials/train/detection/yolov3_darknet53.py @@ -38,18 +38,6 @@ eval_dataset = pdx.datasets.VOCDetection( label_list='insect_det/labels.txt', transforms=eval_transforms) -# 可使用VisualDL查看数据预处理的中间结果 -# VisualDL启动方式: visualdl --logdir vdl_output --port 8001 -# 浏览器打开 https://0.0.0.0:8001即可 -# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP -train_transforms.set_vdl(vdl_save_dir='vdl_output') -for step, data in enumerate(train_dataset.iterator()): - data.append(step) - train_transforms(*data) - if step == 5: - break -train_transforms.release_vdl() - # 初始化模型,并进行训练 # 可使用VisualDL查看训练指标 # VisualDL启动方式: visualdl --logdir output/yolov3_darknet/vdl_log --port 8001 diff --git a/tutorials/train/segmentation/deeplabv3p.py b/tutorials/train/segmentation/deeplabv3p.py index 701d2a1a69b7f58f0a503ed8e2675c048c3bf39e..346a229a358a76830112acfd596740c070822874 100644 --- a/tutorials/train/segmentation/deeplabv3p.py +++ b/tutorials/train/segmentation/deeplabv3p.py @@ -33,18 +33,6 @@ eval_dataset = pdx.datasets.SegDataset( label_list='optic_disc_seg/labels.txt', transforms=eval_transforms) -# 可使用VisualDL查看数据预处理的中间结果 -# VisualDL启动方式: visualdl --logdir vdl_output --port 8001 -# 浏览器打开 https://0.0.0.0:8001即可 -# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP -train_transforms.vdl_save_dir = 'vdl_output' -for step, data in enumerate(train_dataset.iterator()): - data.append(step) - train_transforms(*data) - if step == 5: - break -train_transforms.vdl_save_dir = None - # 初始化模型,并进行训练 # 可使用VisualDL查看训练指标 # VisualDL启动方式: visualdl --logdir output/deeplab/vdl_log --port 8001