提交 4cbf8369 编写于 作者: S sunyanfang01

add transforms vdl

上级 bc44ce9d
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
from .ops import * from .ops import *
from .imgaug_support import execute_imgaug from .imgaug_support import execute_imgaug
import random import random
import os
import os.path as osp import os.path as osp
import numpy as np import numpy as np
from PIL import Image, ImageEnhance from PIL import Image, ImageEnhance
...@@ -57,8 +58,24 @@ class Compose(ClsTransform): ...@@ -57,8 +58,24 @@ class Compose(ClsTransform):
raise Exception( raise Exception(
"Elements in transforms should be defined in 'paddlex.cls.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/" "Elements in transforms should be defined in 'paddlex.cls.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/"
) )
self.images_writer = None
def __call__(self, im, label=None): def set_vdl(self, vdl_save_dir=None):
# 对数据预处理结果在VisualDL中可视化
self.images_writer = None
if vdl_save_dir is not None:
if not osp.isdir(vdl_save_dir):
if osp.exists(vdl_save_dir):
os.remove(vdl_save_dir)
os.makedirs(vdl_save_dir)
from visualdl import LogWriter
vdl_images_dir = osp.join(vdl_save_dir, 'image_transforms')
self.images_writer = LogWriter(vdl_images_dir)
def release_vdl(self):
self.images_writer = None
def __call__(self, im, label=None, step=0):
""" """
Args: Args:
im (str/np.ndarray): 图像路径/图像np.ndarray数据。 im (str/np.ndarray): 图像路径/图像np.ndarray数据。
...@@ -67,6 +84,7 @@ class Compose(ClsTransform): ...@@ -67,6 +84,7 @@ class Compose(ClsTransform):
tuple: 根据网络所需字段所组成的tuple; tuple: 根据网络所需字段所组成的tuple;
字段由transforms中的最后一个数据预处理操作决定。 字段由transforms中的最后一个数据预处理操作决定。
""" """
im_file = str(step)
if isinstance(im, np.ndarray): if isinstance(im, np.ndarray):
if len(im.shape) != 3: if len(im.shape) != 3:
raise Exception( raise Exception(
...@@ -74,10 +92,16 @@ class Compose(ClsTransform): ...@@ -74,10 +92,16 @@ class Compose(ClsTransform):
format(len(im.shape))) format(len(im.shape)))
else: else:
try: try:
im_file = im
im = cv2.imread(im).astype('float32') im = cv2.imread(im).astype('float32')
except: except:
raise TypeError('Can\'t read The image file {}!'.format(im)) raise TypeError('Can\'t read The image file {}!'.format(im))
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
if self.images_writer is not None:
self.images_writer.add_image(tag='0. origin image',
img=im,
step=step)
op_id = 1
for op in self.transforms: for op in self.transforms:
if isinstance(op, ClsTransform): if isinstance(op, ClsTransform):
outputs = op(im, label) outputs = op(im, label)
...@@ -91,6 +115,12 @@ class Compose(ClsTransform): ...@@ -91,6 +115,12 @@ class Compose(ClsTransform):
outputs = (im, ) outputs = (im, )
if label is not None: if label is not None:
outputs = (im, label) outputs = (im, label)
if self.images_writer is not None:
tag = str(op_id) + '. ' + op.__class__.__name__
self.images_writer.add_image(tag=tag,
img=im,
step=step)
op_id += 1
return outputs return outputs
def add_augmenters(self, augmenters): def add_augmenters(self, augmenters):
...@@ -434,6 +464,7 @@ class RandomDistort(ClsTransform): ...@@ -434,6 +464,7 @@ class RandomDistort(ClsTransform):
params['im'] = im params['im'] = im
if np.random.uniform(0, 1) < prob: if np.random.uniform(0, 1) < prob:
im = ops[id](**params) im = ops[id](**params)
im = im.astype('float32')
if label is None: if label is None:
return (im, ) return (im, )
else: else:
......
...@@ -18,6 +18,7 @@ except Exception: ...@@ -18,6 +18,7 @@ except Exception:
from collections import Sequence from collections import Sequence
import random import random
import os
import os.path as osp import os.path as osp
import numpy as np import numpy as np
...@@ -50,7 +51,7 @@ class Compose(DetTransform): ...@@ -50,7 +51,7 @@ class Compose(DetTransform):
ValueError: 数据长度不匹配。 ValueError: 数据长度不匹配。
""" """
def __init__(self, transforms): def __init__(self, transforms, vdl_save_dir=None):
if not isinstance(transforms, list): if not isinstance(transforms, list):
raise TypeError('The transforms must be a list!') raise TypeError('The transforms must be a list!')
if len(transforms) < 1: if len(transforms) < 1:
...@@ -69,8 +70,24 @@ class Compose(DetTransform): ...@@ -69,8 +70,24 @@ class Compose(DetTransform):
raise Exception( raise Exception(
"Elements in transforms should be defined in 'paddlex.det.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/" "Elements in transforms should be defined in 'paddlex.det.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/"
) )
self.images_writer = None
def __call__(self, im, im_info=None, label_info=None):
def set_vdl(self, vdl_save_dir=None):
# 对数据预处理结果在VisualDL中可视化
self.images_writer = None
if vdl_save_dir is not None:
if not osp.isdir(vdl_save_dir):
if osp.exists(vdl_save_dir):
os.remove(vdl_save_dir)
os.makedirs(vdl_save_dir)
from visualdl import LogWriter
vdl_images_dir = osp.join(vdl_save_dir, 'image_transforms')
self.images_writer = LogWriter(vdl_images_dir)
def release_vdl(self):
self.images_writer = None
def __call__(self, im, im_info=None, label_info=None, step=0):
""" """
Args: Args:
im (str/np.ndarray): 图像路径/图像np.ndarray数据。 im (str/np.ndarray): 图像路径/图像np.ndarray数据。
...@@ -134,11 +151,20 @@ class Compose(DetTransform): ...@@ -134,11 +151,20 @@ class Compose(DetTransform):
else: else:
return (im, im_info, label_info) return (im, im_info, label_info)
if isinstance(im, str):
im_file = im
else:
im_file = str(step)
outputs = decode_image(im, im_info, label_info) outputs = decode_image(im, im_info, label_info)
im = outputs[0] im = outputs[0]
im_info = outputs[1] im_info = outputs[1]
if len(outputs) == 3: if len(outputs) == 3:
label_info = outputs[2] label_info = outputs[2]
if self.images_writer is not None:
self.images_writer.add_image(tag='0. origin image',
img=im,
step=step)
op_id = 1
for op in self.transforms: for op in self.transforms:
if im is None: if im is None:
return None return None
...@@ -151,6 +177,12 @@ class Compose(DetTransform): ...@@ -151,6 +177,12 @@ class Compose(DetTransform):
outputs = (im, im_info, label_info) outputs = (im, im_info, label_info)
else: else:
outputs = (im, im_info) outputs = (im, im_info)
if self.images_writer is not None:
tag = str(op_id) + '. ' + op.__class__.__name__
self.images_writer.add_image(tag=tag,
img=im,
step=step)
op_id += 1
return outputs return outputs
def add_augmenters(self, augmenters): def add_augmenters(self, augmenters):
...@@ -621,6 +653,7 @@ class RandomDistort(DetTransform): ...@@ -621,6 +653,7 @@ class RandomDistort(DetTransform):
if np.random.uniform(0, 1) < prob: if np.random.uniform(0, 1) < prob:
im = ops[id](**params) im = ops[id](**params)
im = im.astype('float32')
if label_info is None: if label_info is None:
return (im, im_info) return (im, im_info)
else: else:
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
from .ops import * from .ops import *
from .imgaug_support import execute_imgaug from .imgaug_support import execute_imgaug
import random import random
...@@ -45,7 +46,7 @@ class Compose(SegTransform): ...@@ -45,7 +46,7 @@ class Compose(SegTransform):
""" """
def __init__(self, transforms): def __init__(self, transforms, vdl_save_dir=None):
if not isinstance(transforms, list): if not isinstance(transforms, list):
raise TypeError('The transforms must be a list!') raise TypeError('The transforms must be a list!')
if len(transforms) < 1: if len(transforms) < 1:
...@@ -61,8 +62,24 @@ class Compose(SegTransform): ...@@ -61,8 +62,24 @@ class Compose(SegTransform):
raise Exception( raise Exception(
"Elements in transforms should be defined in 'paddlex.seg.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/" "Elements in transforms should be defined in 'paddlex.seg.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/"
) )
self.images_writer = None
def __call__(self, im, im_info=None, label=None):
def set_vdl(self, vdl_save_dir=None):
# 对数据预处理结果在VisualDL中可视化
self.images_writer = None
if vdl_save_dir is not None:
if not osp.isdir(vdl_save_dir):
if osp.exists(vdl_save_dir):
os.remove(vdl_save_dir)
os.makedirs(vdl_save_dir)
from visualdl import LogWriter
vdl_images_dir = osp.join(vdl_save_dir, 'image_transforms')
self.images_writer = LogWriter(vdl_images_dir)
def release_vdl(self):
self.images_writer = None
def __call__(self, im, im_info=None, label=None, step=0):
""" """
Args: Args:
im (str/np.ndarray): 图像路径/图像np.ndarray数据。 im (str/np.ndarray): 图像路径/图像np.ndarray数据。
...@@ -75,7 +92,7 @@ class Compose(SegTransform): ...@@ -75,7 +92,7 @@ class Compose(SegTransform):
Returns: Returns:
tuple: 根据网络所需字段所组成的tuple;字段由transforms中的最后一个数据预处理操作决定。 tuple: 根据网络所需字段所组成的tuple;字段由transforms中的最后一个数据预处理操作决定。
""" """
im_file = str(step)
if im_info is None: if im_info is None:
im_info = list() im_info = list()
if isinstance(im, np.ndarray): if isinstance(im, np.ndarray):
...@@ -85,6 +102,7 @@ class Compose(SegTransform): ...@@ -85,6 +102,7 @@ class Compose(SegTransform):
format(len(im.shape))) format(len(im.shape)))
else: else:
try: try:
im_file = im
im = cv2.imread(im).astype('float32') im = cv2.imread(im).astype('float32')
except: except:
raise ValueError('Can\'t read The image file {}!'.format(im)) raise ValueError('Can\'t read The image file {}!'.format(im))
...@@ -93,6 +111,11 @@ class Compose(SegTransform): ...@@ -93,6 +111,11 @@ class Compose(SegTransform):
if label is not None: if label is not None:
if not isinstance(label, np.ndarray): if not isinstance(label, np.ndarray):
label = np.asarray(Image.open(label)) label = np.asarray(Image.open(label))
if self.images_writer is not None:
self.images_writer.add_image(tag='0. origin image',
img=im,
step=step)
op_id = 1
for op in self.transforms: for op in self.transforms:
if isinstance(op, SegTransform): if isinstance(op, SegTransform):
outputs = op(im, im_info, label) outputs = op(im, im_info, label)
...@@ -107,6 +130,12 @@ class Compose(SegTransform): ...@@ -107,6 +130,12 @@ class Compose(SegTransform):
outputs = (im, im_info, label) outputs = (im, im_info, label)
else: else:
outputs = (im, im_info) outputs = (im, im_info)
if self.images_writer is not None:
tag = str(op_id) + '. ' + op.__class__.__name__
self.images_writer.add_image(tag=tag,
img=im,
step=step)
op_id += 1
return outputs return outputs
def add_augmenters(self, augmenters): def add_augmenters(self, augmenters):
...@@ -1053,6 +1082,7 @@ class RandomDistort(SegTransform): ...@@ -1053,6 +1082,7 @@ class RandomDistort(SegTransform):
params['im'] = im params['im'] = im
if np.random.uniform(0, 1) < prob: if np.random.uniform(0, 1) < prob:
im = ops[id](**params) im = ops[id](**params)
im = im.astype('float32')
if label is None: if label is None:
return (im, im_info) return (im, im_info)
else: else:
......
...@@ -34,6 +34,19 @@ eval_dataset = pdx.datasets.ImageNet( ...@@ -34,6 +34,19 @@ eval_dataset = pdx.datasets.ImageNet(
label_list='vegetables_cls/labels.txt', label_list='vegetables_cls/labels.txt',
transforms=eval_transforms) transforms=eval_transforms)
# 可使用VisualDL查看数据预处理的中间结果
# VisualDL启动方式: visualdl --logdir vdl_output --port 8001
# 浏览器打开 https://0.0.0.0:8001即可
# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP
train_transforms.set_vdl(vdl_save_dir='vdl_output')
for step, data in enumerate(train_dataset.iterator()):
data.append(step)
train_transforms(*data)
if step == 5:
break
train_transforms.release_vdl()
# 初始化模型,并进行训练 # 初始化模型,并进行训练
# 可使用VisualDL查看训练指标 # 可使用VisualDL查看训练指标
# VisualDL启动方式: visualdl --logdir output/mobilenetv2/vdl_log --port 8001 # VisualDL启动方式: visualdl --logdir output/mobilenetv2/vdl_log --port 8001
......
...@@ -38,6 +38,30 @@ eval_dataset = pdx.datasets.VOCDetection( ...@@ -38,6 +38,30 @@ eval_dataset = pdx.datasets.VOCDetection(
label_list='insect_det/labels.txt', label_list='insect_det/labels.txt',
transforms=eval_transforms) transforms=eval_transforms)
# 可使用VisualDL查看数据预处理的中间结果
# VisualDL启动方式: visualdl --logdir vdl_output --port 8001
# 浏览器打开 https://0.0.0.0:8001即可
# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP
train_transforms.set_vdl(vdl_save_dir='vdl_output')
for step, data in enumerate(train_dataset.iterator()):
data.append(step)
train_transforms(*data)
if step == 5:
break
train_transforms.release_vdl()
# 可使用VisualDL查看数据预处理的中间结果
# VisualDL启动方式: visualdl --logdir vdl_output --port 8001
# 浏览器打开 https://0.0.0.0:8001即可
# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP
train_transforms.vdl_save_dir = 'vdl_output'
for step, data in enumerate(train_dataset.iterator()):
data.append(step)
train_transforms(*data)
if step == 5:
break
train_transforms.vdl_save_dir = None
# 初始化模型,并进行训练 # 初始化模型,并进行训练
# 可使用VisualDL查看训练指标 # 可使用VisualDL查看训练指标
# VisualDL启动方式: visualdl --logdir output/yolov3_darknet/vdl_log --port 8001 # VisualDL启动方式: visualdl --logdir output/yolov3_darknet/vdl_log --port 8001
......
...@@ -33,6 +33,18 @@ eval_dataset = pdx.datasets.SegDataset( ...@@ -33,6 +33,18 @@ eval_dataset = pdx.datasets.SegDataset(
label_list='optic_disc_seg/labels.txt', label_list='optic_disc_seg/labels.txt',
transforms=eval_transforms) transforms=eval_transforms)
# 可使用VisualDL查看数据预处理的中间结果
# VisualDL启动方式: visualdl --logdir vdl_output --port 8001
# 浏览器打开 https://0.0.0.0:8001即可
# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP
train_transforms.vdl_save_dir = 'vdl_output'
for step, data in enumerate(train_dataset.iterator()):
data.append(step)
train_transforms(*data)
if step == 5:
break
train_transforms.vdl_save_dir = None
# 初始化模型,并进行训练 # 初始化模型,并进行训练
# 可使用VisualDL查看训练指标 # 可使用VisualDL查看训练指标
# VisualDL启动方式: visualdl --logdir output/deeplab/vdl_log --port 8001 # VisualDL启动方式: visualdl --logdir output/deeplab/vdl_log --port 8001
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册