From b88fa96f2bff7702554fa8ae751ca045bddba2e3 Mon Sep 17 00:00:00 2001 From: sunyanfang01 Date: Tue, 16 Jun 2020 17:42:56 +0800 Subject: [PATCH] modify the name --- docs/apis/visualize.md | 19 +++++++++++++++++++ paddlex/cv/transforms/cls_transforms.py | 22 +++++++++++++--------- paddlex/cv/transforms/det_transforms.py | 22 +++++++++++++--------- paddlex/cv/transforms/seg_transforms.py | 21 ++++++++++++--------- paddlex/cv/transforms/visualize.py | 16 ++++++++++++++-- 5 files changed, 71 insertions(+), 29 deletions(-) diff --git a/docs/apis/visualize.md b/docs/apis/visualize.md index 0699132..3800688 100755 --- a/docs/apis/visualize.md +++ b/docs/apis/visualize.md @@ -165,3 +165,22 @@ NormLIME是利用一定数量的样本来出一个全局的解释。NormLIME会 ### 使用示例 > 对预测可解释性结果可视化的过程可参见[代码](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/interpret/normlime.py)。 + +## 数据预处理/增强过程可视化 +``` +paddlex.transforms.visualize(dataset, + index=0, + steps=3, + save_dir='vdl_output') +``` +对数据预处理/增强中间结果进行可视化。 +可使用VisualDL查看中间结果: +1. VisualDL启动方式: visualdl --logdir vdl_output --port 8001 +2. 浏览器打开 https://0.0.0.0:8001即可, + 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP + +### 参数 +>* **dataset** (paddlex.datasets): 数据集读取器。 +>* **index** (int): 对数据集中的第index张图像进行可视化。默认为0 +>* **steps** (int): 数据预处理/增强的次数。默认为3。 +>* **save_dir** (str): 日志保存的路径。默认为'vdl_output'。 \ No newline at end of file diff --git a/paddlex/cv/transforms/cls_transforms.py b/paddlex/cv/transforms/cls_transforms.py index 29f6225..f87292b 100644 --- a/paddlex/cv/transforms/cls_transforms.py +++ b/paddlex/cv/transforms/cls_transforms.py @@ -59,11 +59,15 @@ class Compose(ClsTransform): "Elements in transforms should be defined in 'paddlex.cls.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/" ) - def __call__(self, im, label=None, images_writer=None, step=0): + def __call__(self, im, label=None, vdl_writer=None, step=0): """ Args: im (str/np.ndarray): 图像路径/图像np.ndarray数据。 label (int): 每张图像所对应的类别序号。 + vdl_writer (visualdl.LogWriter): VisualDL存储器,日志信息将保存在其中。 + 当为None时,不对日志进行保存。默认为None。 + step (int): 数据预处理的轮数,当vdl_writer不为None时有效。默认为0。 + Returns: tuple: 根据网络所需字段所组成的tuple; 字段由transforms中的最后一个数据预处理操作决定。 @@ -79,10 +83,10 @@ class Compose(ClsTransform): except: raise TypeError('Can\'t read The image file {}!'.format(im)) im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) - if images_writer is not None: - images_writer.add_image(tag='0. origin image', - img=im, - step=step) + if vdl_writer is not None: + vdl_writer.add_image(tag='0. origin image', + img=im, + step=step) op_id = 1 for op in self.transforms: if isinstance(op, ClsTransform): @@ -97,11 +101,11 @@ class Compose(ClsTransform): outputs = (im, ) if label is not None: outputs = (im, label) - if images_writer is not None: + if vdl_writer is not None: tag = str(op_id) + '. ' + op.__class__.__name__ - images_writer.add_image(tag=tag, - img=im, - step=step) + vdl_writer.add_image(tag=tag, + img=im, + step=step) op_id += 1 return outputs diff --git a/paddlex/cv/transforms/det_transforms.py b/paddlex/cv/transforms/det_transforms.py index a4c5092..3f57c21 100644 --- a/paddlex/cv/transforms/det_transforms.py +++ b/paddlex/cv/transforms/det_transforms.py @@ -71,7 +71,7 @@ class Compose(DetTransform): "Elements in transforms should be defined in 'paddlex.det.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/" ) - def __call__(self, im, im_info=None, label_info=None, images_writer=None, step=0): + def __call__(self, im, im_info=None, label_info=None, vdl_writer=None, step=0): """ Args: im (str/np.ndarray): 图像路径/图像np.ndarray数据。 @@ -95,6 +95,10 @@ class Compose(DetTransform): 其中n代表真实标注框的个数。 - difficult (np.ndarray): 每个真实标注框中的对象是否为难识别对象,形状为(n, 1), 其中n代表真实标注框的个数。 + vdl_writer (visualdl.LogWriter): VisualDL存储器,日志信息将保存在其中。 + 当为None时,不对日志进行保存。默认为None。 + step (int): 数据预处理的轮数,当vdl_writer不为None时有效。默认为0。 + Returns: tuple: 根据网络所需字段所组成的tuple; 字段由transforms中的最后一个数据预处理操作决定。 @@ -140,10 +144,10 @@ class Compose(DetTransform): im_info = outputs[1] if len(outputs) == 3: label_info = outputs[2] - if images_writer is not None: - images_writer.add_image(tag='0. origin image', - img=im, - step=step) + if vdl_writer is not None: + vdl_writer.add_image(tag='0. origin image', + img=im, + step=step) op_id = 1 for op in self.transforms: if im is None: @@ -157,11 +161,11 @@ class Compose(DetTransform): outputs = (im, im_info, label_info) else: outputs = (im, im_info) - if images_writer is not None: + if vdl_writer is not None: tag = str(op_id) + '. ' + op.__class__.__name__ - images_writer.add_image(tag=tag, - img=im, - step=step) + vdl_writer.add_image(tag=tag, + img=im, + step=step) op_id += 1 return outputs diff --git a/paddlex/cv/transforms/seg_transforms.py b/paddlex/cv/transforms/seg_transforms.py index 7a21281..0eade48 100644 --- a/paddlex/cv/transforms/seg_transforms.py +++ b/paddlex/cv/transforms/seg_transforms.py @@ -63,7 +63,7 @@ class Compose(SegTransform): "Elements in transforms should be defined in 'paddlex.seg.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/" ) - def __call__(self, im, im_info=None, label=None, images_writer=None, step=0): + def __call__(self, im, im_info=None, label=None, vdl_writer=None, step=0): """ Args: im (str/np.ndarray): 图像路径/图像np.ndarray数据。 @@ -72,6 +72,9 @@ class Compose(SegTransform): 图像在过resize前shape为(200, 300), 过padding前shape为 (400, 600) label (str/np.ndarray): 标注图像路径/标注图像np.ndarray数据。 + vdl_writer (visualdl.LogWriter): VisualDL存储器,日志信息将保存在其中。 + 当为None时,不对日志进行保存。默认为None。 + step (int): 数据预处理的轮数,当vdl_writer不为None时有效。默认为0。 Returns: tuple: 根据网络所需字段所组成的tuple;字段由transforms中的最后一个数据预处理操作决定。 @@ -93,10 +96,10 @@ class Compose(SegTransform): if label is not None: if not isinstance(label, np.ndarray): label = np.asarray(Image.open(label)) - if images_writer is not None: - images_writer.add_image(tag='0. origin image', - img=im, - step=step) + if vdl_writer is not None: + vdl_writer.add_image(tag='0. origin image', + img=im, + step=step) op_id = 1 for op in self.transforms: if isinstance(op, SegTransform): @@ -112,11 +115,11 @@ class Compose(SegTransform): outputs = (im, im_info, label) else: outputs = (im, im_info) - if images_writer is not None: + if vdl_writer is not None: tag = str(op_id) + '. ' + op.__class__.__name__ - images_writer.add_image(tag=tag, - img=im, - step=step) + vdl_writer.add_image(tag=tag, + img=im, + step=step) op_id += 1 return outputs diff --git a/paddlex/cv/transforms/visualize.py b/paddlex/cv/transforms/visualize.py index 02df46c..ace8112 100644 --- a/paddlex/cv/transforms/visualize.py +++ b/paddlex/cv/transforms/visualize.py @@ -19,6 +19,18 @@ from .det_transforms import DetTransform from .seg_transforms import SegTransform def visualize(dataset, index=0, steps=3, save_dir='vdl_output'): + '''对数据预处理/增强中间结果进行可视化。 + 可使用VisualDL查看中间结果: + 1. VisualDL启动方式: visualdl --logdir vdl_output --port 8001 + 2. 浏览器打开 https://0.0.0.0:8001即可, + 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP + + Args: + dataset (paddlex.datasets): 数据集读取器。 + index (int): 对数据集中的第index张图像进行可视化。默认为0 + steps (int): 数据预处理/增强的次数。默认为3。 + save_dir (str): 日志保存的路径。默认为'vdl_output'。 + ''' transforms = dataset.transforms if not osp.isdir(save_dir): if osp.exists(save_dir): @@ -29,8 +41,8 @@ def visualize(dataset, index=0, steps=3, save_dir='vdl_output'): break from visualdl import LogWriter vdl_save_dir = osp.join(save_dir, 'image_transforms') - images_writer = LogWriter(vdl_save_dir) - data.append(images_writer) + vdl_writer = LogWriter(vdl_save_dir) + data.append(vdl_writer) for s in range(steps): if s != 0: data.pop() -- GitLab