提交 b88fa96f 编写于 作者: S sunyanfang01

modify the name

上级 de059b20
...@@ -165,3 +165,22 @@ NormLIME是利用一定数量的样本来出一个全局的解释。NormLIME会 ...@@ -165,3 +165,22 @@ NormLIME是利用一定数量的样本来出一个全局的解释。NormLIME会
### 使用示例 ### 使用示例
> 对预测可解释性结果可视化的过程可参见[代码](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/interpret/normlime.py)。 > 对预测可解释性结果可视化的过程可参见[代码](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/interpret/normlime.py)。
## 数据预处理/增强过程可视化
```
paddlex.transforms.visualize(dataset,
index=0,
steps=3,
save_dir='vdl_output')
```
对数据预处理/增强中间结果进行可视化。
可使用VisualDL查看中间结果:
1. VisualDL启动方式: visualdl --logdir vdl_output --port 8001
2. 浏览器打开 https://0.0.0.0:8001即可,
其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP
### 参数
>* **dataset** (paddlex.datasets): 数据集读取器。
>* **index** (int): 对数据集中的第index张图像进行可视化。默认为0
>* **steps** (int): 数据预处理/增强的次数。默认为3。
>* **save_dir** (str): 日志保存的路径。默认为'vdl_output'。
\ No newline at end of file
...@@ -59,11 +59,15 @@ class Compose(ClsTransform): ...@@ -59,11 +59,15 @@ class Compose(ClsTransform):
"Elements in transforms should be defined in 'paddlex.cls.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/" "Elements in transforms should be defined in 'paddlex.cls.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/"
) )
def __call__(self, im, label=None, images_writer=None, step=0): def __call__(self, im, label=None, vdl_writer=None, step=0):
""" """
Args: Args:
im (str/np.ndarray): 图像路径/图像np.ndarray数据。 im (str/np.ndarray): 图像路径/图像np.ndarray数据。
label (int): 每张图像所对应的类别序号。 label (int): 每张图像所对应的类别序号。
vdl_writer (visualdl.LogWriter): VisualDL存储器,日志信息将保存在其中。
当为None时,不对日志进行保存。默认为None。
step (int): 数据预处理的轮数,当vdl_writer不为None时有效。默认为0。
Returns: Returns:
tuple: 根据网络所需字段所组成的tuple; tuple: 根据网络所需字段所组成的tuple;
字段由transforms中的最后一个数据预处理操作决定。 字段由transforms中的最后一个数据预处理操作决定。
...@@ -79,10 +83,10 @@ class Compose(ClsTransform): ...@@ -79,10 +83,10 @@ class Compose(ClsTransform):
except: except:
raise TypeError('Can\'t read The image file {}!'.format(im)) raise TypeError('Can\'t read The image file {}!'.format(im))
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
if images_writer is not None: if vdl_writer is not None:
images_writer.add_image(tag='0. origin image', vdl_writer.add_image(tag='0. origin image',
img=im, img=im,
step=step) step=step)
op_id = 1 op_id = 1
for op in self.transforms: for op in self.transforms:
if isinstance(op, ClsTransform): if isinstance(op, ClsTransform):
...@@ -97,11 +101,11 @@ class Compose(ClsTransform): ...@@ -97,11 +101,11 @@ class Compose(ClsTransform):
outputs = (im, ) outputs = (im, )
if label is not None: if label is not None:
outputs = (im, label) outputs = (im, label)
if images_writer is not None: if vdl_writer is not None:
tag = str(op_id) + '. ' + op.__class__.__name__ tag = str(op_id) + '. ' + op.__class__.__name__
images_writer.add_image(tag=tag, vdl_writer.add_image(tag=tag,
img=im, img=im,
step=step) step=step)
op_id += 1 op_id += 1
return outputs return outputs
......
...@@ -71,7 +71,7 @@ class Compose(DetTransform): ...@@ -71,7 +71,7 @@ class Compose(DetTransform):
"Elements in transforms should be defined in 'paddlex.det.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/" "Elements in transforms should be defined in 'paddlex.det.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/"
) )
def __call__(self, im, im_info=None, label_info=None, images_writer=None, step=0): def __call__(self, im, im_info=None, label_info=None, vdl_writer=None, step=0):
""" """
Args: Args:
im (str/np.ndarray): 图像路径/图像np.ndarray数据。 im (str/np.ndarray): 图像路径/图像np.ndarray数据。
...@@ -95,6 +95,10 @@ class Compose(DetTransform): ...@@ -95,6 +95,10 @@ class Compose(DetTransform):
其中n代表真实标注框的个数。 其中n代表真实标注框的个数。
- difficult (np.ndarray): 每个真实标注框中的对象是否为难识别对象,形状为(n, 1), - difficult (np.ndarray): 每个真实标注框中的对象是否为难识别对象,形状为(n, 1),
其中n代表真实标注框的个数。 其中n代表真实标注框的个数。
vdl_writer (visualdl.LogWriter): VisualDL存储器,日志信息将保存在其中。
当为None时,不对日志进行保存。默认为None。
step (int): 数据预处理的轮数,当vdl_writer不为None时有效。默认为0。
Returns: Returns:
tuple: 根据网络所需字段所组成的tuple; tuple: 根据网络所需字段所组成的tuple;
字段由transforms中的最后一个数据预处理操作决定。 字段由transforms中的最后一个数据预处理操作决定。
...@@ -140,10 +144,10 @@ class Compose(DetTransform): ...@@ -140,10 +144,10 @@ class Compose(DetTransform):
im_info = outputs[1] im_info = outputs[1]
if len(outputs) == 3: if len(outputs) == 3:
label_info = outputs[2] label_info = outputs[2]
if images_writer is not None: if vdl_writer is not None:
images_writer.add_image(tag='0. origin image', vdl_writer.add_image(tag='0. origin image',
img=im, img=im,
step=step) step=step)
op_id = 1 op_id = 1
for op in self.transforms: for op in self.transforms:
if im is None: if im is None:
...@@ -157,11 +161,11 @@ class Compose(DetTransform): ...@@ -157,11 +161,11 @@ class Compose(DetTransform):
outputs = (im, im_info, label_info) outputs = (im, im_info, label_info)
else: else:
outputs = (im, im_info) outputs = (im, im_info)
if images_writer is not None: if vdl_writer is not None:
tag = str(op_id) + '. ' + op.__class__.__name__ tag = str(op_id) + '. ' + op.__class__.__name__
images_writer.add_image(tag=tag, vdl_writer.add_image(tag=tag,
img=im, img=im,
step=step) step=step)
op_id += 1 op_id += 1
return outputs return outputs
......
...@@ -63,7 +63,7 @@ class Compose(SegTransform): ...@@ -63,7 +63,7 @@ class Compose(SegTransform):
"Elements in transforms should be defined in 'paddlex.seg.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/" "Elements in transforms should be defined in 'paddlex.seg.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/"
) )
def __call__(self, im, im_info=None, label=None, images_writer=None, step=0): def __call__(self, im, im_info=None, label=None, vdl_writer=None, step=0):
""" """
Args: Args:
im (str/np.ndarray): 图像路径/图像np.ndarray数据。 im (str/np.ndarray): 图像路径/图像np.ndarray数据。
...@@ -72,6 +72,9 @@ class Compose(SegTransform): ...@@ -72,6 +72,9 @@ class Compose(SegTransform):
图像在过resize前shape为(200, 300), 过padding前shape为 图像在过resize前shape为(200, 300), 过padding前shape为
(400, 600) (400, 600)
label (str/np.ndarray): 标注图像路径/标注图像np.ndarray数据。 label (str/np.ndarray): 标注图像路径/标注图像np.ndarray数据。
vdl_writer (visualdl.LogWriter): VisualDL存储器,日志信息将保存在其中。
当为None时,不对日志进行保存。默认为None。
step (int): 数据预处理的轮数,当vdl_writer不为None时有效。默认为0。
Returns: Returns:
tuple: 根据网络所需字段所组成的tuple;字段由transforms中的最后一个数据预处理操作决定。 tuple: 根据网络所需字段所组成的tuple;字段由transforms中的最后一个数据预处理操作决定。
...@@ -93,10 +96,10 @@ class Compose(SegTransform): ...@@ -93,10 +96,10 @@ class Compose(SegTransform):
if label is not None: if label is not None:
if not isinstance(label, np.ndarray): if not isinstance(label, np.ndarray):
label = np.asarray(Image.open(label)) label = np.asarray(Image.open(label))
if images_writer is not None: if vdl_writer is not None:
images_writer.add_image(tag='0. origin image', vdl_writer.add_image(tag='0. origin image',
img=im, img=im,
step=step) step=step)
op_id = 1 op_id = 1
for op in self.transforms: for op in self.transforms:
if isinstance(op, SegTransform): if isinstance(op, SegTransform):
...@@ -112,11 +115,11 @@ class Compose(SegTransform): ...@@ -112,11 +115,11 @@ class Compose(SegTransform):
outputs = (im, im_info, label) outputs = (im, im_info, label)
else: else:
outputs = (im, im_info) outputs = (im, im_info)
if images_writer is not None: if vdl_writer is not None:
tag = str(op_id) + '. ' + op.__class__.__name__ tag = str(op_id) + '. ' + op.__class__.__name__
images_writer.add_image(tag=tag, vdl_writer.add_image(tag=tag,
img=im, img=im,
step=step) step=step)
op_id += 1 op_id += 1
return outputs return outputs
......
...@@ -19,6 +19,18 @@ from .det_transforms import DetTransform ...@@ -19,6 +19,18 @@ from .det_transforms import DetTransform
from .seg_transforms import SegTransform from .seg_transforms import SegTransform
def visualize(dataset, index=0, steps=3, save_dir='vdl_output'): def visualize(dataset, index=0, steps=3, save_dir='vdl_output'):
'''对数据预处理/增强中间结果进行可视化。
可使用VisualDL查看中间结果:
1. VisualDL启动方式: visualdl --logdir vdl_output --port 8001
2. 浏览器打开 https://0.0.0.0:8001即可,
其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP
Args:
dataset (paddlex.datasets): 数据集读取器。
index (int): 对数据集中的第index张图像进行可视化。默认为0
steps (int): 数据预处理/增强的次数。默认为3。
save_dir (str): 日志保存的路径。默认为'vdl_output'。
'''
transforms = dataset.transforms transforms = dataset.transforms
if not osp.isdir(save_dir): if not osp.isdir(save_dir):
if osp.exists(save_dir): if osp.exists(save_dir):
...@@ -29,8 +41,8 @@ def visualize(dataset, index=0, steps=3, save_dir='vdl_output'): ...@@ -29,8 +41,8 @@ def visualize(dataset, index=0, steps=3, save_dir='vdl_output'):
break break
from visualdl import LogWriter from visualdl import LogWriter
vdl_save_dir = osp.join(save_dir, 'image_transforms') vdl_save_dir = osp.join(save_dir, 'image_transforms')
images_writer = LogWriter(vdl_save_dir) vdl_writer = LogWriter(vdl_save_dir)
data.append(images_writer) data.append(vdl_writer)
for s in range(steps): for s in range(steps):
if s != 0: if s != 0:
data.pop() data.pop()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册