提交 b4e23302 编写于 作者: C chenguowei01

rm ArrageSegmeter

上级 578f83f0
...@@ -111,7 +111,7 @@ def infer(model, data_dir=None, test_list=None, model_dir=None, ...@@ -111,7 +111,7 @@ def infer(model, data_dir=None, test_list=None, model_dir=None,
for file in tqdm.tqdm(files): for file in tqdm.tqdm(files):
file = file.strip() file = file.strip()
im_file = osp.join(data_dir, file) im_file = osp.join(data_dir, file)
im, im_info = transforms(im_file) im, im_info, _ = transforms(im_file)
im = np.expand_dims(im, axis=0) im = np.expand_dims(im, axis=0)
im = to_variable(im) im = to_variable(im)
...@@ -140,17 +140,8 @@ def infer(model, data_dir=None, test_list=None, model_dir=None, ...@@ -140,17 +140,8 @@ def infer(model, data_dir=None, test_list=None, model_dir=None,
cv2.imwrite(pred_saved_path, pred_im) cv2.imwrite(pred_saved_path, pred_im)
def arrange_transform(transforms, mode='train'):
arrange_transform = T.ArrangeSegmenter
if type(transforms.transforms[-1]).__name__.startswith('Arrange'):
transforms.transforms[-1] = arrange_transform(mode=mode)
else:
transforms.transforms.append(arrange_transform(mode=mode))
def main(args): def main(args):
test_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()]) test_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
arrange_transform(test_transforms, mode='test')
if args.model_name == 'UNet': if args.model_name == 'UNet':
model = models.UNet(num_classes=args.num_classes) model = models.UNet(num_classes=args.num_classes)
......
...@@ -143,7 +143,7 @@ def train(model, ...@@ -143,7 +143,7 @@ def train(model,
for epoch in range(num_epochs): for epoch in range(num_epochs):
for step, data in enumerate(data_generator()): for step, data in enumerate(data_generator()):
images = np.array([d[0] for d in data]) images = np.array([d[0] for d in data])
labels = np.array([d[1] for d in data]).astype('int64') labels = np.array([d[2] for d in data]).astype('int64')
images = to_variable(images) images = to_variable(images)
labels = to_variable(labels) labels = to_variable(labels)
loss = model(images, labels, mode='train') loss = model(images, labels, mode='train')
...@@ -175,21 +175,12 @@ def train(model, ...@@ -175,21 +175,12 @@ def train(model,
model.train() model.train()
def arrange_transform(transforms, mode='train'):
arrange_transform = T.ArrangeSegmenter
if type(transforms.transforms[-1]).__name__.startswith('Arrange'):
transforms.transforms[-1] = arrange_transform(mode=mode)
else:
transforms.transforms.append(arrange_transform(mode=mode))
def main(args): def main(args):
# Creat dataset reader # Creat dataset reader
train_transforms = T.Compose( train_transforms = T.Compose(
[T.Resize(args.input_size), [T.Resize(args.input_size),
T.RandomHorizontalFlip(), T.RandomHorizontalFlip(),
T.Normalize()]) T.Normalize()])
arrange_transform(train_transforms, mode='train')
train_dataset = Dataset( train_dataset = Dataset(
data_dir=args.data_dir, data_dir=args.data_dir,
file_list=args.train_list, file_list=args.train_list,
...@@ -200,7 +191,6 @@ def main(args): ...@@ -200,7 +191,6 @@ def main(args):
shuffle=True) shuffle=True)
if args.val_list is not None: if args.val_list is not None:
eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()]) eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
arrange_transform(eval_transforms, mode='eval')
eval_dataset = Dataset( eval_dataset = Dataset(
data_dir=args.data_dir, data_dir=args.data_dir,
file_list=args.val_list, file_list=args.val_list,
......
...@@ -74,7 +74,10 @@ class Compose: ...@@ -74,7 +74,10 @@ class Compose:
im_info = outputs[1] im_info = outputs[1]
if len(outputs) == 3: if len(outputs) == 3:
label = outputs[2] label = outputs[2]
return outputs im = permute(im)
if len(outputs) == 3:
label = label[np.newaxis, :, :]
return (im, im_info, label)
class RandomHorizontalFlip: class RandomHorizontalFlip:
...@@ -873,42 +876,3 @@ class RandomDistort: ...@@ -873,42 +876,3 @@ class RandomDistort:
return (im, im_info) return (im, im_info)
else: else:
return (im, im_info, label) return (im, im_info, label)
class ArrangeSegmenter:
"""获取训练/验证/预测所需的信息。
Args:
mode (str): 指定数据用于何种用途,取值范围为['train', 'eval', 'test', 'quant']。
Raises:
ValueError: mode的取值不在['train', 'eval', 'test', 'quant']之内
"""
def __init__(self, mode):
if mode not in ['train', 'eval', 'test', 'quant']:
raise ValueError(
"mode should be defined as one of ['train', 'eval', 'test', 'quant']!"
)
self.mode = mode
def __call__(self, im, im_info, label=None):
"""
Args:
im (np.ndarray): 图像np.ndarray数据。
im_info (dict): 存储与图像相关的信息。
label (np.ndarray): 标注图像np.ndarray数据。
Returns:
tuple: 当mode为'train'或'eval'时,返回的tuple为(im, label),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
当mode为'test'时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;当mode为
'quant'时,返回的tuple为(im,),为图像np.ndarray数据。
"""
im = permute(im)
if self.mode == 'train' or self.mode == 'eval':
label = label[np.newaxis, :, :]
return (im, label)
elif self.mode == 'test':
return (im, im_info)
else:
return (im, )
...@@ -21,7 +21,7 @@ from paddle.fluid.dygraph.base import to_variable ...@@ -21,7 +21,7 @@ from paddle.fluid.dygraph.base import to_variable
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
from datasets.dataset import Dataset from datasets import Dataset
import transforms as T import transforms as T
import models import models
import utils.logging as logging import utils.logging as logging
...@@ -112,7 +112,7 @@ def evaluate(model, ...@@ -112,7 +112,7 @@ def evaluate(model,
eval_dataset.num_samples, total_steps)) eval_dataset.num_samples, total_steps))
for step, data in enumerate(data_generator()): for step, data in enumerate(data_generator()):
images = np.array([d[0] for d in data]) images = np.array([d[0] for d in data])
labels = np.array([d[1] for d in data]).astype('int64') labels = np.array([d[2] for d in data]).astype('int64')
images = to_variable(images) images = to_variable(images)
pred, _ = model(images, labels, mode='eval') pred, _ = model(images, labels, mode='eval')
...@@ -134,17 +134,8 @@ def evaluate(model, ...@@ -134,17 +134,8 @@ def evaluate(model,
logging.info("[EVAL] Kappa:{:.4f} ".format(conf_mat.kappa())) logging.info("[EVAL] Kappa:{:.4f} ".format(conf_mat.kappa()))
def arrange_transform(transforms, mode='train'):
arrange_transform = T.ArrangeSegmenter
if type(transforms.transforms[-1]).__name__.startswith('Arrange'):
transforms.transforms[-1] = arrange_transform(mode=mode)
else:
transforms.transforms.append(arrange_transform(mode=mode))
def main(args): def main(args):
eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()]) eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
arrange_transform(eval_transforms, mode='eval')
eval_dataset = Dataset( eval_dataset = Dataset(
data_dir=args.data_dir, data_dir=args.data_dir,
file_list=args.val_list, file_list=args.val_list,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册