提交 564417a4 编写于 作者: F FlyingQianMM

add multichannel RemoteSensing

上级 38980c5c
# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp
import argparse
from paddlex.seg import transforms
import paddlex.RemoteSensing.transforms as custom_transforms
import paddlex as pdx
def parse_args(argv=None):
    """Parse the command-line options of the RemoteSensing training script.

    Args:
        argv (list[str] | None): Argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            what the script itself relies on; passing a list makes the
            parser usable from other code.

    Returns:
        argparse.Namespace: The parsed options. ``clip_min_value``,
        ``clip_max_value``, ``mean`` and ``std`` are lists (one or more
        values each) when given on the command line and ``None`` otherwise.
    """
    parser = argparse.ArgumentParser(description='RemoteSensing training')
    parser.add_argument(
        '--data_dir',
        dest='data_dir',
        help='dataset directory',
        default=None,
        type=str)
    parser.add_argument(
        '--save_dir',
        dest='save_dir',
        help='model save directory',
        default=None,
        type=str)
    parser.add_argument(
        '--num_classes',
        dest='num_classes',
        help='Number of classes',
        default=None,
        type=int)
    parser.add_argument(
        '--channel',
        dest='channel',
        help='number of data channel',
        default=3,
        type=int)
    # The next four options accept one or more values (nargs='+').
    parser.add_argument(
        '--clip_min_value',
        dest='clip_min_value',
        help='Min values for clipping data',
        nargs='+',
        default=None,
        type=int)
    parser.add_argument(
        '--clip_max_value',
        dest='clip_max_value',
        help='Max values for clipping data',
        nargs='+',
        default=None,
        type=int)
    parser.add_argument(
        '--mean',
        dest='mean',
        help='Data means',
        nargs='+',
        default=None,
        type=float)
    parser.add_argument(
        '--std',
        dest='std',
        help='Data standard deviation',
        nargs='+',
        default=None,
        type=float)
    parser.add_argument(
        '--num_epochs',
        dest='num_epochs',
        help='number of training epochs',
        default=100,
        type=int)
    parser.add_argument(
        '--train_batch_size',
        dest='train_batch_size',
        help='training batch size',
        default=4,
        type=int)
    parser.add_argument(
        '--lr', dest='lr', help='learning rate', default=0.01, type=float)
    return parser.parse_args(argv)
# Read the run configuration once and expose each option under the
# module-level name used by the rest of the script.
args = parse_args()
data_dir, save_dir = args.data_dir, args.save_dir
num_classes, channel = args.num_classes, args.channel
clip_min_value, clip_max_value = args.clip_min_value, args.clip_max_value
mean, std = args.mean, args.std
num_epochs, train_batch_size = args.num_epochs, args.train_batch_size
lr = args.lr

# Define the transforms used for training and for validation.
# Clip and Normalize receive the same value range, so it is spelled once.
clip_range = dict(min_val=clip_min_value, max_val=clip_max_value)

train_transforms = transforms.Compose([
    transforms.RandomVerticalFlip(0.5),
    transforms.RandomHorizontalFlip(0.5),
    transforms.ResizeStepScaling(0.5, 2.0, 0.25),
    # One padding value per input channel.
    transforms.RandomPaddingCrop(im_padding_value=[1000] * channel),
    custom_transforms.Clip(**clip_range),
    transforms.Normalize(mean=mean, std=std, **clip_range),
])
# Replace the default image decoder with the RemoteSensing one.
train_transforms.decode_image = custom_transforms.decode_image

eval_transforms = transforms.Compose([
    custom_transforms.Clip(**clip_range),
    transforms.Normalize(mean=mean, std=std, **clip_range),
])
eval_transforms.decode_image = custom_transforms.decode_image

# File lists are looked up inside the dataset directory.
train_list = osp.join(data_dir, 'train.txt')
val_list = osp.join(data_dir, 'val.txt')
label_list = osp.join(data_dir, 'labels.txt')

dataset_common = dict(data_dir=data_dir, label_list=label_list)
train_dataset = pdx.datasets.SegDataset(
    file_list=train_list,
    transforms=train_transforms,
    shuffle=True,
    **dataset_common)
eval_dataset = pdx.datasets.SegDataset(
    file_list=val_list,
    transforms=eval_transforms,
    **dataset_common)

# UNet whose input layer takes `channel` channels.
model = pdx.seg.UNet(num_classes=num_classes, input_channel=channel)
train_options = dict(
    num_epochs=num_epochs,
    train_dataset=train_dataset,
    train_batch_size=train_batch_size,
    eval_dataset=eval_dataset,
    save_interval_epochs=5,
    log_interval_steps=10,
    save_dir=save_dir,
    learning_rate=lr,
    use_vdl=True)
model.train(**train_options)
import os
import os.path as osp
import imghdr
import gdal
import numpy as np
from PIL import Image
from paddlex.seg import transforms
def read_img(img_path):
    """Load an image file into a numpy array.

    The reader is chosen from the content type reported by ``imghdr.what``
    and from the file extension:

    * TIFF content or a ``.img`` extension: read through GDAL, then the
      first axis is moved to the end (band-first to band-last).
    * PNG content: read through PIL.
    * ``.npy`` extension: read with ``numpy.load``.

    Args:
        img_path (str): Path of the file to read.

    Returns:
        numpy.ndarray: The image data.

    Raises:
        Exception: If GDAL cannot open the file, or the format is not one
            of those listed above.
    """
    img_format = imghdr.what(img_path)
    _, ext = osp.splitext(img_path)

    if img_format == 'tiff' or ext == '.img':
        dataset = gdal.Open(img_path)
        if dataset is None:
            raise Exception('Can not open {}'.format(img_path))
        im_data = dataset.ReadAsArray()
        if im_data.ndim == 2:
            # A single-band raster comes back without a band axis; add one
            # so the transpose below works and the result is (H, W, 1).
            im_data = im_data[np.newaxis, :, :]
        return im_data.transpose((1, 2, 0))
    if img_format == 'png':
        return np.asarray(Image.open(img_path))
    if ext == '.npy':
        return np.load(img_path)
    raise Exception('Image format {} is not supported!'.format(ext))
def decode_image(im, label):
    """Turn an image/label pair into arrays, reading from disk if needed.

    Args:
        im (numpy.ndarray | str): Either an already-decoded image, which
            must be 3-dimensional, or a path handed to ``read_img``.
        label (numpy.ndarray | str | None): Either an already-decoded
            label, a path handed to ``read_img``, or ``None`` when there
            is no label.

    Returns:
        tuple: ``(im, label)``; arrays passed in are returned unchanged.

    Raises:
        Exception: If ``im`` is an array that is not 3-dimensional.
        ValueError: If ``im`` is a path and reading it fails.
    """
    if isinstance(im, np.ndarray):
        if len(im.shape) != 3:
            raise Exception(
                "im should be 3-dimensions, but now is {}-dimensions".format(
                    len(im.shape)))
    else:
        try:
            im = read_img(im)
        except Exception:
            # Only ordinary errors are converted; KeyboardInterrupt and
            # SystemExit must keep propagating unchanged.
            raise ValueError('Can\'t read The image file {}!'.format(im))
    if label is not None and not isinstance(label, np.ndarray):
        label = read_img(label)
    return (im, label)
class Clip(transforms.SegTransform):
    """Clamp the values of every image channel to a per-channel range.

    Args:
        min_val (list): Lower bounds, indexed by channel; values below the
            bound are set to it. Defaults to [0, 0, 0].
        max_val (list): Upper bounds, indexed by channel; values above the
            bound are set to it. Defaults to [255.0, 255.0, 255.0].

    Raises:
        ValueError: If ``min_val`` or ``max_val`` is not a list.
    """

    def __init__(self, min_val=[0, 0, 0], max_val=[255.0, 255.0, 255.0]):
        self.min_val = min_val
        self.max_val = max_val
        bounds_are_lists = (isinstance(self.min_val, list)
                            and isinstance(self.max_val, list))
        if not bounds_are_lists:
            raise ValueError("{}: input type is invalid.".format(self))

    def __call__(self, im, im_info=None, label=None):
        """Clip ``im`` in place, one channel (last axis) at a time.

        Args:
            im (numpy.ndarray): Image indexed as ``im[:, :, channel]``.
            im_info: Passed through untouched.
            label: Passed through untouched; may be ``None``.

        Returns:
            tuple: ``(im, im_info)`` when ``label`` is ``None``, otherwise
            ``(im, im_info, label)``.
        """
        num_channels = im.shape[2]
        for channel in range(num_channels):
            # The slice is a view, so writing into it updates `im` itself.
            plane = im[:, :, channel]
            np.clip(
                plane,
                self.min_val[channel],
                self.max_val[channel],
                out=plane)
        if label is not None:
            return (im, im_info, label)
        return (im, im_info)
...@@ -32,6 +32,7 @@ from . import slim ...@@ -32,6 +32,7 @@ from . import slim
from . import convertor from . import convertor
from . import tools from . import tools
from . import deploy from . import deploy
from . import RemoteSensing
try: try:
import pycocotools import pycocotools
......
...@@ -67,8 +67,6 @@ class SegDataset(Dataset): ...@@ -67,8 +67,6 @@ class SegDataset(Dataset):
items = line.strip().split() items = line.strip().split()
items[0] = path_normalization(items[0]) items[0] = path_normalization(items[0])
items[1] = path_normalization(items[1]) items[1] = path_normalization(items[1])
if not is_pic(items[0]):
continue
full_path_im = osp.join(data_dir, items[0]) full_path_im = osp.join(data_dir, items[0])
full_path_label = osp.join(data_dir, items[1]) full_path_label = osp.join(data_dir, items[1])
if not osp.exists(full_path_im): if not osp.exists(full_path_im):
......
...@@ -43,6 +43,7 @@ class UNet(DeepLabv3p): ...@@ -43,6 +43,7 @@ class UNet(DeepLabv3p):
def __init__(self, def __init__(self,
num_classes=2, num_classes=2,
input_channel=3,
upsample_mode='bilinear', upsample_mode='bilinear',
use_bce_loss=False, use_bce_loss=False,
use_dice_loss=False, use_dice_loss=False,
...@@ -71,6 +72,7 @@ class UNet(DeepLabv3p): ...@@ -71,6 +72,7 @@ class UNet(DeepLabv3p):
'Expect class_weight is a list or string but receive {}'. 'Expect class_weight is a list or string but receive {}'.
format(type(class_weight))) format(type(class_weight)))
self.num_classes = num_classes self.num_classes = num_classes
self.input_channel = input_channel
self.upsample_mode = upsample_mode self.upsample_mode = upsample_mode
self.use_bce_loss = use_bce_loss self.use_bce_loss = use_bce_loss
self.use_dice_loss = use_dice_loss self.use_dice_loss = use_dice_loss
...@@ -82,6 +84,7 @@ class UNet(DeepLabv3p): ...@@ -82,6 +84,7 @@ class UNet(DeepLabv3p):
def build_net(self, mode='train'): def build_net(self, mode='train'):
model = paddlex.cv.nets.segmentation.UNet( model = paddlex.cv.nets.segmentation.UNet(
self.num_classes, self.num_classes,
input_channel=self.input_channel,
mode=mode, mode=mode,
upsample_mode=self.upsample_mode, upsample_mode=self.upsample_mode,
use_bce_loss=self.use_bce_loss, use_bce_loss=self.use_bce_loss,
......
...@@ -64,6 +64,7 @@ class UNet(object): ...@@ -64,6 +64,7 @@ class UNet(object):
def __init__(self, def __init__(self,
num_classes, num_classes,
input_channel=3,
mode='train', mode='train',
upsample_mode='bilinear', upsample_mode='bilinear',
use_bce_loss=False, use_bce_loss=False,
...@@ -92,6 +93,7 @@ class UNet(object): ...@@ -92,6 +93,7 @@ class UNet(object):
'Expect class_weight is a list or string but receive {}'. 'Expect class_weight is a list or string but receive {}'.
format(type(class_weight))) format(type(class_weight)))
self.num_classes = num_classes self.num_classes = num_classes
self.input_channel = input_channel
self.mode = mode self.mode = mode
self.upsample_mode = upsample_mode self.upsample_mode = upsample_mode
self.use_bce_loss = use_bce_loss self.use_bce_loss = use_bce_loss
...@@ -232,13 +234,16 @@ class UNet(object): ...@@ -232,13 +234,16 @@ class UNet(object):
if self.fixed_input_shape is not None: if self.fixed_input_shape is not None:
input_shape = [ input_shape = [
None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0] None, self.input_channel, self.fixed_input_shape[1],
self.fixed_input_shape[0]
] ]
inputs['image'] = fluid.data( inputs['image'] = fluid.data(
dtype='float32', shape=input_shape, name='image') dtype='float32', shape=input_shape, name='image')
else: else:
inputs['image'] = fluid.data( inputs['image'] = fluid.data(
dtype='float32', shape=[None, 3, None, None], name='image') dtype='float32',
shape=[None, self.input_channel, None, None],
name='image')
if self.mode == 'train': if self.mode == 'train':
inputs['label'] = fluid.data( inputs['label'] = fluid.data(
dtype='int32', shape=[None, 1, None, None], name='label') dtype='int32', shape=[None, 1, None, None], name='label')
......
...@@ -18,8 +18,12 @@ import numpy as np ...@@ -18,8 +18,12 @@ import numpy as np
from PIL import Image, ImageEnhance from PIL import Image, ImageEnhance
def normalize(im, mean, std): def normalize(im, mean, std, min_value, max_value):
im = im / 255.0 # Rescaling (min-max normalization)
range_value = [max_value[i] - min_value[i] for i in range(len(max_value))]
im = (im - min_value) / range_value
# Standardization (Z-score Normalization)
im -= mean im -= mean
im /= std im /= std
return im return im
......
...@@ -60,6 +60,24 @@ class Compose(SegTransform): ...@@ -60,6 +60,24 @@ class Compose(SegTransform):
"Elements in transforms should be defined in 'paddlex.seg.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/" "Elements in transforms should be defined in 'paddlex.seg.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/"
) )
@staticmethod
def decode_image(im, label):
    """Decode an image/label pair into arrays.

    Args:
        im (numpy.ndarray | str): A 3-dimensional image array, or a path
            that is read with ``cv2.imread``.
        label (numpy.ndarray | str | None): A label array, a path that is
            opened with PIL, or ``None`` when there is no label.

    Returns:
        tuple: ``(im, label)`` where ``im`` has been cast to float32.

    Raises:
        Exception: If ``im`` is an array that is not 3-dimensional.
        ValueError: If ``im`` is a path and the file cannot be read.
    """
    if isinstance(im, np.ndarray):
        if len(im.shape) != 3:
            raise Exception(
                "im should be 3-dimensions, but now is {}-dimensions".
                format(len(im.shape)))
    else:
        try:
            decoded = cv2.imread(im)
        except Exception:
            raise ValueError('Can\'t read The image file {}!'.format(im))
        if decoded is None:
            # cv2.imread signals an unreadable file by returning None
            # rather than raising, so that case is reported explicitly
            # instead of failing later on `None.astype`.
            raise ValueError('Can\'t read The image file {}!'.format(im))
        im = decoded
    im = im.astype('float32')
    if label is not None and not isinstance(label, np.ndarray):
        label = np.asarray(Image.open(label))
    return (im, label)
def __call__(self, im, im_info=None, label=None): def __call__(self, im, im_info=None, label=None):
""" """
Args: Args:
...@@ -73,24 +91,12 @@ class Compose(SegTransform): ...@@ -73,24 +91,12 @@ class Compose(SegTransform):
tuple: 根据网络所需字段所组成的tuple;字段由transforms中的最后一个数据预处理操作决定。 tuple: 根据网络所需字段所组成的tuple;字段由transforms中的最后一个数据预处理操作决定。
""" """
if isinstance(im, np.ndarray): im, label = self.decode_image(im, label)
if len(im.shape) != 3:
raise Exception(
"im should be 3-dimensions, but now is {}-dimensions".
format(len(im.shape)))
else:
try:
im = cv2.imread(im)
except:
raise ValueError('Can\'t read The image file {}!'.format(im))
im = im.astype('float32')
if im_info is None:
im_info = [('origin_shape', im.shape[0:2])]
if self.to_rgb: if self.to_rgb:
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
if im_info is None:
im_info = [('origin_shape', im.shape[0:2])]
if label is not None: if label is not None:
if not isinstance(label, np.ndarray):
label = np.asarray(Image.open(label))
origin_label = label.copy() origin_label = label.copy()
for op in self.transforms: for op in self.transforms:
if isinstance(op, SegTransform): if isinstance(op, SegTransform):
...@@ -561,11 +567,21 @@ class Normalize(SegTransform): ...@@ -561,11 +567,21 @@ class Normalize(SegTransform):
ValueError: mean或std不是list对象。std包含0。 ValueError: mean或std不是list对象。std包含0。
""" """
def __init__(self, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]): def __init__(self,
mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5],
min_val=[0, 0, 0],
max_val=[255.0, 255.0, 255.0]):
self.min_val = min_val
self.max_val = max_val
self.mean = mean self.mean = mean
self.std = std self.std = std
if not (isinstance(self.mean, list) and isinstance(self.std, list)): if not (isinstance(self.mean, list) and isinstance(self.std, list)):
raise ValueError("{}: input type is invalid.".format(self)) raise ValueError("{}: input type is invalid.".format(self))
if not (isinstance(self.min_val, list) and isinstance(self.max_val,
list)):
raise ValueError("{}: input type is invalid.".format(self))
from functools import reduce from functools import reduce
if reduce(lambda x, y: x * y, self.std) == 0: if reduce(lambda x, y: x * y, self.std) == 0:
raise ValueError('{}: std is invalid!'.format(self)) raise ValueError('{}: std is invalid!'.format(self))
...@@ -588,7 +604,7 @@ class Normalize(SegTransform): ...@@ -588,7 +604,7 @@ class Normalize(SegTransform):
mean = np.array(self.mean)[np.newaxis, np.newaxis, :] mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :] std = np.array(self.std)[np.newaxis, np.newaxis, :]
im = normalize(im, mean, std) im = normalize(im, mean, std, self.min_val, self.max_val)
if label is None: if label is None:
return (im, im_info) return (im, im_info)
...@@ -752,23 +768,26 @@ class RandomPaddingCrop(SegTransform): ...@@ -752,23 +768,26 @@ class RandomPaddingCrop(SegTransform):
pad_height = max(crop_height - img_height, 0) pad_height = max(crop_height - img_height, 0)
pad_width = max(crop_width - img_width, 0) pad_width = max(crop_width - img_width, 0)
if (pad_height > 0 or pad_width > 0): if (pad_height > 0 or pad_width > 0):
im = cv2.copyMakeBorder( img_channel = im.shape[2]
im, import copy
0, orig_im = copy.deepcopy(im)
pad_height, im = np.zeros((img_height + pad_height, img_width + pad_width,
0, img_channel)).astype(orig_im.dtype)
pad_width, for i in range(img_channel):
cv2.BORDER_CONSTANT, im[:, :, i] = np.pad(
value=self.im_padding_value) orig_im[:, :, i],
pad_width=((0, pad_height), (0, pad_width)),
mode='constant',
constant_values=(self.im_padding_value[i],
self.im_padding_value[i]))
if label is not None: if label is not None:
label = cv2.copyMakeBorder( label = np.pad(label,
label, pad_width=((0, pad_height), (0, pad_width)),
0, mode='constant',
pad_height, constant_values=(self.label_padding_value,
0, self.label_padding_value))
pad_width,
cv2.BORDER_CONSTANT,
value=self.label_padding_value)
img_height = im.shape[0] img_height = im.shape[0]
img_width = im.shape[1] img_width = im.shape[1]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册