提交 de059b20 编写于 作者: S sunyanfang01

modify

上级 7c234f7a
......@@ -48,6 +48,7 @@ if hub.version.hub_version < '1.6.2':
env_info = get_environ_info()
load_model = cv.models.load_model
datasets = cv.datasets
transforms = cv.transforms
log_level = 2
......
......@@ -15,3 +15,5 @@
from . import cls_transforms
from . import det_transforms
from . import seg_transforms
from . import visualize
visualize = visualize.visualize
......@@ -58,24 +58,8 @@ class Compose(ClsTransform):
raise Exception(
"Elements in transforms should be defined in 'paddlex.cls.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/"
)
self.images_writer = None
def set_vdl(self, vdl_save_dir=None):
# 对数据预处理结果在VisualDL中可视化
self.images_writer = None
if vdl_save_dir is not None:
if not osp.isdir(vdl_save_dir):
if osp.exists(vdl_save_dir):
os.remove(vdl_save_dir)
os.makedirs(vdl_save_dir)
from visualdl import LogWriter
vdl_images_dir = osp.join(vdl_save_dir, 'image_transforms')
self.images_writer = LogWriter(vdl_images_dir)
def release_vdl(self):
self.images_writer = None
def __call__(self, im, label=None, step=0):
def __call__(self, im, label=None, images_writer=None, step=0):
"""
Args:
im (str/np.ndarray): 图像路径/图像np.ndarray数据。
......@@ -84,7 +68,6 @@ class Compose(ClsTransform):
tuple: 根据网络所需字段所组成的tuple;
字段由transforms中的最后一个数据预处理操作决定。
"""
im_file = str(step)
if isinstance(im, np.ndarray):
if len(im.shape) != 3:
raise Exception(
......@@ -92,15 +75,14 @@ class Compose(ClsTransform):
format(len(im.shape)))
else:
try:
im_file = im
im = cv2.imread(im).astype('float32')
except:
raise TypeError('Can\'t read The image file {}!'.format(im))
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
if self.images_writer is not None:
self.images_writer.add_image(tag='0. origin image',
img=im,
step=step)
if images_writer is not None:
images_writer.add_image(tag='0. origin image',
img=im,
step=step)
op_id = 1
for op in self.transforms:
if isinstance(op, ClsTransform):
......@@ -115,9 +97,9 @@ class Compose(ClsTransform):
outputs = (im, )
if label is not None:
outputs = (im, label)
if self.images_writer is not None:
if images_writer is not None:
tag = str(op_id) + '. ' + op.__class__.__name__
self.images_writer.add_image(tag=tag,
images_writer.add_image(tag=tag,
img=im,
step=step)
op_id += 1
......
......@@ -51,7 +51,7 @@ class Compose(DetTransform):
ValueError: 数据长度不匹配。
"""
def __init__(self, transforms, vdl_save_dir=None):
def __init__(self, transforms):
if not isinstance(transforms, list):
raise TypeError('The transforms must be a list!')
if len(transforms) < 1:
......@@ -70,24 +70,8 @@ class Compose(DetTransform):
raise Exception(
"Elements in transforms should be defined in 'paddlex.det.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/"
)
self.images_writer = None
def set_vdl(self, vdl_save_dir=None):
# 对数据预处理结果在VisualDL中可视化
self.images_writer = None
if vdl_save_dir is not None:
if not osp.isdir(vdl_save_dir):
if osp.exists(vdl_save_dir):
os.remove(vdl_save_dir)
os.makedirs(vdl_save_dir)
from visualdl import LogWriter
vdl_images_dir = osp.join(vdl_save_dir, 'image_transforms')
self.images_writer = LogWriter(vdl_images_dir)
def release_vdl(self):
self.images_writer = None
def __call__(self, im, im_info=None, label_info=None, step=0):
def __call__(self, im, im_info=None, label_info=None, images_writer=None, step=0):
"""
Args:
im (str/np.ndarray): 图像路径/图像np.ndarray数据。
......@@ -151,19 +135,15 @@ class Compose(DetTransform):
else:
return (im, im_info, label_info)
if isinstance(im, str):
im_file = im
else:
im_file = str(step)
outputs = decode_image(im, im_info, label_info)
im = outputs[0]
im_info = outputs[1]
if len(outputs) == 3:
label_info = outputs[2]
if self.images_writer is not None:
self.images_writer.add_image(tag='0. origin image',
img=im,
step=step)
if images_writer is not None:
images_writer.add_image(tag='0. origin image',
img=im,
step=step)
op_id = 1
for op in self.transforms:
if im is None:
......@@ -177,9 +157,9 @@ class Compose(DetTransform):
outputs = (im, im_info, label_info)
else:
outputs = (im, im_info)
if self.images_writer is not None:
if images_writer is not None:
tag = str(op_id) + '. ' + op.__class__.__name__
self.images_writer.add_image(tag=tag,
images_writer.add_image(tag=tag,
img=im,
step=step)
op_id += 1
......
......@@ -46,7 +46,7 @@ class Compose(SegTransform):
"""
def __init__(self, transforms, vdl_save_dir=None):
def __init__(self, transforms):
if not isinstance(transforms, list):
raise TypeError('The transforms must be a list!')
if len(transforms) < 1:
......@@ -62,24 +62,8 @@ class Compose(SegTransform):
raise Exception(
"Elements in transforms should be defined in 'paddlex.seg.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/"
)
self.images_writer = None
def set_vdl(self, vdl_save_dir=None):
# 对数据预处理结果在VisualDL中可视化
self.images_writer = None
if vdl_save_dir is not None:
if not osp.isdir(vdl_save_dir):
if osp.exists(vdl_save_dir):
os.remove(vdl_save_dir)
os.makedirs(vdl_save_dir)
from visualdl import LogWriter
vdl_images_dir = osp.join(vdl_save_dir, 'image_transforms')
self.images_writer = LogWriter(vdl_images_dir)
def release_vdl(self):
self.images_writer = None
def __call__(self, im, im_info=None, label=None, step=0):
def __call__(self, im, im_info=None, label=None, images_writer=None, step=0):
"""
Args:
im (str/np.ndarray): 图像路径/图像np.ndarray数据。
......@@ -92,7 +76,6 @@ class Compose(SegTransform):
Returns:
tuple: 根据网络所需字段所组成的tuple;字段由transforms中的最后一个数据预处理操作决定。
"""
im_file = str(step)
if im_info is None:
im_info = list()
if isinstance(im, np.ndarray):
......@@ -102,7 +85,6 @@ class Compose(SegTransform):
format(len(im.shape)))
else:
try:
im_file = im
im = cv2.imread(im).astype('float32')
except:
raise ValueError('Can\'t read The image file {}!'.format(im))
......@@ -111,10 +93,10 @@ class Compose(SegTransform):
if label is not None:
if not isinstance(label, np.ndarray):
label = np.asarray(Image.open(label))
if self.images_writer is not None:
self.images_writer.add_image(tag='0. origin image',
img=im,
step=step)
if images_writer is not None:
images_writer.add_image(tag='0. origin image',
img=im,
step=step)
op_id = 1
for op in self.transforms:
if isinstance(op, SegTransform):
......@@ -130,9 +112,9 @@ class Compose(SegTransform):
outputs = (im, im_info, label)
else:
outputs = (im, im_info)
if self.images_writer is not None:
if images_writer is not None:
tag = str(op_id) + '. ' + op.__class__.__name__
self.images_writer.add_image(tag=tag,
images_writer.add_image(tag=tag,
img=im,
step=step)
op_id += 1
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import os.path as osp
from .cls_transforms import ClsTransform
from .det_transforms import DetTransform
from .seg_transforms import SegTransform
def visualize(dataset, index=0, steps=3, save_dir='vdl_output'):
transforms = dataset.transforms
if not osp.isdir(save_dir):
if osp.exists(save_dir):
os.remove(save_dir)
os.makedirs(save_dir)
for i, data in enumerate(dataset.iterator()):
if i == index:
break
from visualdl import LogWriter
vdl_save_dir = osp.join(save_dir, 'image_transforms')
images_writer = LogWriter(vdl_save_dir)
data.append(images_writer)
for s in range(steps):
if s != 0:
data.pop()
data.append(s)
transforms(*data)
\ No newline at end of file
......@@ -34,19 +34,6 @@ eval_dataset = pdx.datasets.ImageNet(
label_list='vegetables_cls/labels.txt',
transforms=eval_transforms)
# 可使用VisualDL查看数据预处理的中间结果
# VisualDL启动方式: visualdl --logdir vdl_output --port 8001
# 浏览器打开 https://0.0.0.0:8001即可
# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP
train_transforms.set_vdl(vdl_save_dir='vdl_output')
for step, data in enumerate(train_dataset.iterator()):
data.append(step)
train_transforms(*data)
if step == 5:
break
train_transforms.release_vdl()
# 初始化模型,并进行训练
# 可使用VisualDL查看训练指标
# VisualDL启动方式: visualdl --logdir output/mobilenetv2/vdl_log --port 8001
......
......@@ -38,18 +38,6 @@ eval_dataset = pdx.datasets.VOCDetection(
label_list='insect_det/labels.txt',
transforms=eval_transforms)
# 可使用VisualDL查看数据预处理的中间结果
# VisualDL启动方式: visualdl --logdir vdl_output --port 8001
# 浏览器打开 https://0.0.0.0:8001即可
# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP
train_transforms.set_vdl(vdl_save_dir='vdl_output')
for step, data in enumerate(train_dataset.iterator()):
data.append(step)
train_transforms(*data)
if step == 5:
break
train_transforms.release_vdl()
# 初始化模型,并进行训练
# 可使用VisualDL查看训练指标
# VisualDL启动方式: visualdl --logdir output/yolov3_darknet/vdl_log --port 8001
......
......@@ -33,18 +33,6 @@ eval_dataset = pdx.datasets.SegDataset(
label_list='optic_disc_seg/labels.txt',
transforms=eval_transforms)
# 可使用VisualDL查看数据预处理的中间结果
# VisualDL启动方式: visualdl --logdir vdl_output --port 8001
# 浏览器打开 https://0.0.0.0:8001即可
# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP
train_transforms.vdl_save_dir = 'vdl_output'
for step, data in enumerate(train_dataset.iterator()):
data.append(step)
train_transforms(*data)
if step == 5:
break
train_transforms.vdl_save_dir = None
# 初始化模型,并进行训练
# 可使用VisualDL查看训练指标
# VisualDL启动方式: visualdl --logdir output/deeplab/vdl_log --port 8001
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册