diff --git a/dygraph/README.md b/dygraph/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..20728e6bb5321e78122b9e925544419e9ec4483c
--- /dev/null
+++ b/dygraph/README.md
@@ -0,0 +1,34 @@
+# Dynamic Graph Execution
+
+## Dataset Setup
+
+The supported datasets ('OpticDiscSeg' and 'Cityscapes') are downloaded and
+uncompressed automatically on first use, so no manual path setup is needed.
+Pick one with the `--dataset` flag shown below.
+
+## Training
+```
+python3 train.py --model_name UNet \
+--dataset OpticDiscSeg \
+--input_size 192 192 \
+--num_epochs 4 \
+--save_interval_epochs 1 \
+--do_eval \
+--save_dir output
+```
+
+## Evaluation
+```
+python3 val.py --model_name UNet \
+--dataset OpticDiscSeg \
+--input_size 192 192 \
+--model_dir output/epoch_1
+```
+
+## Prediction
+```
+python3 infer.py --model_name UNet \
+--dataset OpticDiscSeg \
+--input_size 192 192 \
+--model_dir output/epoch_1
+```
diff --git a/dygraph/datasets/__init__.py b/dygraph/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..072a82f7409a9369d2c3b1bdba603527eac0bb7f
--- /dev/null
+++ b/dygraph/datasets/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .optic_disc_seg import OpticDiscSeg
+from .cityscapes import Cityscapes
diff --git a/dygraph/datasets/cityscapes.py b/dygraph/datasets/cityscapes.py
new file mode 100644
index 0000000000000000000000000000000000000000..21f967820ec32aa37b1877ae7d583eb3e5aac674
--- /dev/null
+++ b/dygraph/datasets/cityscapes.py
@@ -0,0 +1,85 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
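+
+# A minimal usage sketch (illustrative only, not executed by the pipeline):
+# both dataset classes exported from this package share one constructor
+# signature, so they are interchangeable. Assumes scripts run from the
+# `dygraph/` directory so that `transforms` and `datasets` are importable:
+#
+#     import transforms as T
+#     from datasets import Cityscapes
+#
+#     tfm = T.Compose([T.Resize(512), T.Normalize()])
+#     dataset = Cityscapes(transforms=tfm, mode='train')  # downloads on first use
+#     im, label = dataset[0]  # CHW float32 image, 1HW integer label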
+
+import os
+
+from paddle.fluid.io import Dataset
+
+from utils.download import download_file_and_uncompress
+
+DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
+URL = "https://paddleseg.bj.bcebos.com/dataset/cityscapes.tar"
+
+
+class Cityscapes(Dataset):
+    def __init__(self,
+                 data_dir=None,
+                 transforms=None,
+                 mode='train',
+                 download=True):
+        self.data_dir = data_dir
+        self.transforms = transforms
+        self.file_list = list()
+        self.mode = mode
+        self.num_classes = 19
+
+        if mode.lower() not in ['train', 'eval', 'test']:
+            raise Exception(
+                "mode should be 'train', 'eval' or 'test', but got {}.".format(
+                    mode))
+
+        if self.transforms is None:
+            raise Exception("transforms is necessary, but it is None.")
+
+        self.data_dir = data_dir
+        if self.data_dir is None:
+            if not download:
+                raise Exception(
+                    "data_dir not set and auto download disabled.")
+            self.data_dir = download_file_and_uncompress(
+                url=URL, savepath=DATA_HOME, extrapath=DATA_HOME)
+
+        if mode == 'train':
+            file_list = os.path.join(self.data_dir, 'train.list')
+        elif mode == 'eval':
+            file_list = os.path.join(self.data_dir, 'val.list')
+        else:
+            file_list = os.path.join(self.data_dir, 'test.list')
+
+        with open(file_list, 'r') as f:
+            for line in f:
+                items = line.strip().split()
+                if len(items) != 2:
+                    if mode == 'train' or mode == 'eval':
+                        raise Exception(
+                            "File list format incorrect! It should be"
+                            " image_name label_name\\n")
+                    image_path = os.path.join(self.data_dir, items[0])
+                    grt_path = None
+                else:
+                    image_path = os.path.join(self.data_dir, items[0])
+                    grt_path = os.path.join(self.data_dir, items[1])
+                self.file_list.append([image_path, grt_path])
+
+    def __getitem__(self, idx):
+        image_path, grt_path = self.file_list[idx]
+        im, im_info, label = self.transforms(im=image_path, label=grt_path)
+        if self.mode == 'train':
+            return im, label
+        elif self.mode == 'eval':
+            return im, label
+        if self.mode == 'test':
+            return im, im_info, image_path
+
+    def __len__(self):
+        return len(self.file_list)
diff --git a/dygraph/datasets/optic_disc_seg.py b/dygraph/datasets/optic_disc_seg.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a321915e90c18e99e46d0e53473e695b1ec2317
--- /dev/null
+++ b/dygraph/datasets/optic_disc_seg.py
@@ -0,0 +1,85 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
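+
+# The file lists consumed below use the same convention as cityscapes.py:
+# one "image_path label_path" pair per line, relative to the dataset root,
+# e.g. (paths are illustrative):
+#
+#     images/N0001.jpg annotations/N0001.png
+#     images/N0002.jpg annotations/N0002.png
+#
+# Test lists may omit the label column; train/eval lists must not.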
+
+import os
+
+from paddle.fluid.io import Dataset
+
+from utils.download import download_file_and_uncompress
+
+DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
+URL = "https://paddleseg.bj.bcebos.com/dataset/optic_disc_seg.zip"
+
+
+class OpticDiscSeg(Dataset):
+    def __init__(self,
+                 data_dir=None,
+                 transforms=None,
+                 mode='train',
+                 download=True):
+        self.data_dir = data_dir
+        self.transforms = transforms
+        self.file_list = list()
+        self.mode = mode
+        self.num_classes = 2
+
+        if mode.lower() not in ['train', 'eval', 'test']:
+            raise Exception(
+                "mode should be 'train', 'eval' or 'test', but got {}.".format(
+                    mode))
+
+        if self.transforms is None:
+            raise Exception("transforms is necessary, but it is None.")
+
+        self.data_dir = data_dir
+        if self.data_dir is None:
+            if not download:
+                raise Exception(
+                    "data_dir not set and auto download disabled.")
+            self.data_dir = download_file_and_uncompress(
+                url=URL, savepath=DATA_HOME, extrapath=DATA_HOME)
+
+        if mode == 'train':
+            file_list = os.path.join(self.data_dir, 'train_list.txt')
+        elif mode == 'eval':
+            file_list = os.path.join(self.data_dir, 'val_list.txt')
+        else:
+            file_list = os.path.join(self.data_dir, 'test_list.txt')
+
+        with open(file_list, 'r') as f:
+            for line in f:
+                items = line.strip().split()
+                if len(items) != 2:
+                    if mode == 'train' or mode == 'eval':
+                        raise Exception(
+                            "File list format incorrect! It should be"
+                            " image_name label_name\\n")
+                    image_path = os.path.join(self.data_dir, items[0])
+                    grt_path = None
+                else:
+                    image_path = os.path.join(self.data_dir, items[0])
+                    grt_path = os.path.join(self.data_dir, items[1])
+                self.file_list.append([image_path, grt_path])
+
+    def __getitem__(self, idx):
+        image_path, grt_path = self.file_list[idx]
+        im, im_info, label = self.transforms(im=image_path, label=grt_path)
+        if self.mode == 'train':
+            return im, label
+        elif self.mode == 'eval':
+            return im, label
+        if self.mode == 'test':
+            return im, im_info, image_path
+
+    def __len__(self):
+        return len(self.file_list)
diff --git a/dygraph/infer.py b/dygraph/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..af745a39b025e9e804c207989c939d454d7ff25f
--- /dev/null
+++ b/dygraph/infer.py
@@ -0,0 +1,161 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
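+
+# Example invocation (assumes a checkpoint saved by train.py under
+# output/epoch_1; the flags mirror parse_args below):
+#
+#     python3 infer.py --model_name UNet --dataset OpticDiscSeg \
+#         --input_size 512 512 --model_dir output/epoch_1 \
+#         --save_dir output/result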
+
+import argparse
+import os
+
+from paddle.fluid.dygraph.base import to_variable
+import numpy as np
+import paddle.fluid as fluid
+from paddle.fluid.dygraph.parallel import ParallelEnv
+import cv2
+import tqdm
+
+from datasets import OpticDiscSeg, Cityscapes
+import transforms as T
+import models
+import utils
+import utils.logging as logging
+from utils import get_environ_info
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Model prediction')
+
+    # params of model
+    parser.add_argument(
+        '--model_name',
+        dest='model_name',
+        help="Model type for prediction, which is one of ('UNet')",
+        type=str,
+        default='UNet')
+
+    # params of dataset
+    parser.add_argument(
+        '--dataset',
+        dest='dataset',
+        help="The dataset you want to predict on, which is one of"
+        " ('OpticDiscSeg', 'Cityscapes')",
+        type=str,
+        default='OpticDiscSeg')
+
+    # params of prediction
+    parser.add_argument(
+        "--input_size",
+        dest="input_size",
+        help="The image size for net inputs.",
+        nargs=2,
+        default=[512, 512],
+        type=int)
+    parser.add_argument(
+        '--batch_size',
+        dest='batch_size',
+        help='Mini batch size',
+        type=int,
+        default=2)
+    parser.add_argument(
+        '--model_dir',
+        dest='model_dir',
+        help='The path of the model used for prediction',
+        type=str,
+        default=None)
+    parser.add_argument(
+        '--save_dir',
+        dest='save_dir',
+        help='The directory for saving the inference results',
+        type=str,
+        default='./output/result')
+
+    return parser.parse_args()
+
+
+def mkdir(path):
+    sub_dir = os.path.dirname(path)
+    if not os.path.exists(sub_dir):
+        os.makedirs(sub_dir)
+
+
+def infer(model, test_dataset=None, model_dir=None, save_dir='output'):
+    ckpt_path = os.path.join(model_dir, 'model')
+    para_state_dict, opti_state_dict = fluid.load_dygraph(ckpt_path)
+    model.set_dict(para_state_dict)
+    model.eval()
+
+    added_saved_dir = os.path.join(save_dir, 'added')
+    pred_saved_dir = os.path.join(save_dir, 'prediction')
+
+    logging.info("Start to predict...")
+    for im, im_info, im_path in tqdm.tqdm(test_dataset):
+        im = im[np.newaxis, ...]
+        im = to_variable(im)
+        pred, _ = model(im, mode='test')
+        pred = pred.numpy()
+        pred = np.squeeze(pred).astype('uint8')
+        # Undo the transforms in reverse order to recover the original shape
+        keys = list(im_info.keys())
+        for k in keys[::-1]:
+            if k == 'shape_before_resize':
+                h, w = im_info[k][0], im_info[k][1]
+                pred = cv2.resize(pred, (w, h), interpolation=cv2.INTER_NEAREST)
+            elif k == 'shape_before_padding':
+                h, w = im_info[k][0], im_info[k][1]
+                pred = pred[0:h, 0:w]
+
+        im_file = im_path.replace(test_dataset.data_dir, '')
+        if im_file[0] == '/':
+            im_file = im_file[1:]
+        # save added image
+        added_image = utils.visualize(im_path, pred, weight=0.6)
+        added_image_path = os.path.join(added_saved_dir, im_file)
+        mkdir(added_image_path)
+        cv2.imwrite(added_image_path, added_image)
+
+        # save prediction
+        pred_im = utils.visualize(im_path, pred, weight=0.0)
+        pred_saved_path = os.path.join(pred_saved_dir, im_file)
+        mkdir(pred_saved_path)
+        cv2.imwrite(pred_saved_path, pred_im)
+
+
+def main(args):
+    env_info = get_environ_info()
+    places = fluid.CUDAPlace(ParallelEnv().dev_id) \
+        if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \
+        else fluid.CPUPlace()
+
+    if args.dataset.lower() == 'opticdiscseg':
+        dataset = OpticDiscSeg
+    elif args.dataset.lower() == 'cityscapes':
+        dataset = Cityscapes
+    else:
+        raise Exception(
+            "The --dataset is wrong. It should be one of"
+            " ('OpticDiscSeg', 'Cityscapes').")
+
+    with fluid.dygraph.guard(places):
+        test_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
+        test_dataset = dataset(transforms=test_transforms, mode='test')
+
+        if args.model_name == 'UNet':
+            model = models.UNet(num_classes=test_dataset.num_classes)
+
+        infer(
+            model,
+            model_dir=args.model_dir,
+            test_dataset=test_dataset,
+            save_dir=args.save_dir)
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    main(args)
diff --git a/dygraph/models/__init__.py b/dygraph/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..792059698bdbc5f95acbd18a0f3cbc6b6ec769e5
--- /dev/null
+++ b/dygraph/models/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .unet import UNet
diff --git a/dygraph/models/unet.py b/dygraph/models/unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..b55e3614b6988a0102eb3e6f17093e59673eae70
--- /dev/null
+++ b/dygraph/models/unet.py
@@ -0,0 +1,165 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
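+
+# Shape walk-through for a 3 x 512 x 512 input through the encoder defined
+# below (each Down halves H and W; channel widths follow the constructors):
+#
+#     input        3 x 512 x 512
+#     double_conv 64 x 512 x 512  -> short_cuts[0]
+#     down1      128 x 256 x 256  -> short_cuts[1]
+#     down2      256 x 128 x 128  -> short_cuts[2]
+#     down3      512 x  64 x  64  -> short_cuts[3]
+#     down4      512 x  32 x  32  -> encoder output
+#
+# The decoder mirrors this path: each Up bilinearly resizes to the matching
+# short cut's spatial size and concatenates along channels before DoubleConv.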
+ +import paddle.fluid as fluid +from paddle.fluid.dygraph import Conv2D, BatchNorm, Pool2D + + +class UNet(fluid.dygraph.Layer): + def __init__(self, num_classes, ignore_index=255): + super().__init__() + self.encode = UnetEncoder() + self.decode = UnetDecode() + self.get_logit = GetLogit(64, num_classes) + self.ignore_index = ignore_index + self.EPS = 1e-5 + + def forward(self, x, label=None, mode='train'): + encode_data, short_cuts = self.encode(x) + decode_data = self.decode(encode_data, short_cuts) + logit = self.get_logit(decode_data) + if mode == 'train': + return self._get_loss(logit, label) + else: + score_map = fluid.layers.softmax(logit, axis=1) + score_map = fluid.layers.transpose(score_map, [0, 2, 3, 1]) + pred = fluid.layers.argmax(score_map, axis=3) + pred = fluid.layers.unsqueeze(pred, axes=[3]) + return pred, score_map + + def _get_loss(self, logit, label): + mask = label != self.ignore_index + mask = fluid.layers.cast(mask, 'float32') + loss, probs = fluid.layers.softmax_with_cross_entropy( + logit, + label, + ignore_index=self.ignore_index, + return_softmax=True, + axis=1) + + loss = loss * mask + avg_loss = fluid.layers.mean(loss) / ( + fluid.layers.mean(mask) + self.EPS) + + label.stop_gradient = True + mask.stop_gradient = True + return avg_loss + + +class UnetEncoder(fluid.dygraph.Layer): + def __init__(self): + super().__init__() + self.double_conv = DoubleConv(3, 64) + self.down1 = Down(64, 128) + self.down2 = Down(128, 256) + self.down3 = Down(256, 512) + self.down4 = Down(512, 512) + + def forward(self, x): + short_cuts = [] + x = self.double_conv(x) + short_cuts.append(x) + x = self.down1(x) + short_cuts.append(x) + x = self.down2(x) + short_cuts.append(x) + x = self.down3(x) + short_cuts.append(x) + x = self.down4(x) + return x, short_cuts + + +class UnetDecode(fluid.dygraph.Layer): + def __init__(self): + super().__init__() + self.up1 = Up(512, 256) + self.up2 = Up(256, 128) + self.up3 = Up(128, 64) + self.up4 = Up(64, 64) + + def forward(self, x, short_cuts): + x = self.up1(x, short_cuts[3]) + x = self.up2(x, short_cuts[2]) + x = self.up3(x, short_cuts[1]) + x = self.up4(x, short_cuts[0]) + return x + + +class DoubleConv(fluid.dygraph.Layer): + def __init__(self, num_channels, num_filters): + super().__init__() + self.conv0 = Conv2D( + num_channels=num_channels, + num_filters=num_filters, + filter_size=3, + stride=1, + padding=1) + self.bn0 = BatchNorm(num_channels=num_filters) + self.conv1 = Conv2D( + num_channels=num_filters, + num_filters=num_filters, + filter_size=3, + stride=1, + padding=1) + self.bn1 = BatchNorm(num_channels=num_filters) + + def forward(self, x): + x = self.conv0(x) + x = self.bn0(x) + x = fluid.layers.relu(x) + x = self.conv1(x) + x = self.bn1(x) + x = fluid.layers.relu(x) + return x + + +class Down(fluid.dygraph.Layer): + def __init__(self, num_channels, num_filters): + super().__init__() + self.max_pool = Pool2D( + pool_size=2, pool_type='max', pool_stride=2, pool_padding=0) + self.double_conv = DoubleConv(num_channels, num_filters) + + def forward(self, x): + x = self.max_pool(x) + x = self.double_conv(x) + return x + + +class Up(fluid.dygraph.Layer): + def __init__(self, num_channels, num_filters): + super().__init__() + self.double_conv = DoubleConv(2 * num_channels, num_filters) + + def forward(self, x, short_cut): + short_cut_shape = fluid.layers.shape(short_cut) + x = fluid.layers.resize_bilinear(x, short_cut_shape[2:]) + x = fluid.layers.concat([x, short_cut], axis=1) + x = self.double_conv(x) + return x + + +class 
GetLogit(fluid.dygraph.Layer): + def __init__(self, num_channels, num_classes): + super().__init__() + self.conv = Conv2D( + num_channels=num_channels, + num_filters=num_classes, + filter_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = self.conv(x) + return x diff --git a/dygraph/train.py b/dygraph/train.py new file mode 100644 index 0000000000000000000000000000000000000000..88b1ccb64bcb7f9862a4f81a17ee1cd392db36ae --- /dev/null +++ b/dygraph/train.py @@ -0,0 +1,253 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +import paddle.fluid as fluid +from paddle.fluid.dygraph.parallel import ParallelEnv +from paddle.fluid.io import DataLoader +from paddle.incubate.hapi.distributed import DistributedBatchSampler + +from datasets import OpticDiscSeg, Cityscapes +import transforms as T +import models +import utils.logging as logging +from utils import get_environ_info +from utils import load_pretrained_model +from val import evaluate + + +def parse_args(): + parser = argparse.ArgumentParser(description='Model training') + + # params of model + parser.add_argument( + '--model_name', + dest='model_name', + help="Model type for traing, which is one of ('UNet')", + type=str, + default='UNet') + + # params of dataset + parser.add_argument( + '--dataset', + dest='dataset', + help= + "The dataset you want to train, which is one of ('OpticDiscSeg', 'Cityscapes')", + type=str, + default='OpticDiscSeg') + + # params of training + parser.add_argument( + "--input_size", + dest="input_size", + help="The image size for net inputs.", + nargs=2, + default=[512, 512], + type=int) + parser.add_argument( + '--num_epochs', + dest='num_epochs', + help='Number epochs for training', + type=int, + default=100) + parser.add_argument( + '--batch_size', + dest='batch_size', + help='Mini batch size of one gpu or cpu', + type=int, + default=2) + parser.add_argument( + '--learning_rate', + dest='learning_rate', + help='Learning rate', + type=float, + default=0.01) + parser.add_argument( + '--pretrained_model', + dest='pretrained_model', + help='The path of pretrained weight', + type=str, + default=None) + parser.add_argument( + '--save_interval_epochs', + dest='save_interval_epochs', + help='The interval epochs for save a model snapshot', + type=int, + default=5) + parser.add_argument( + '--save_dir', + dest='save_dir', + help='The directory for saving the model snapshot', + type=str, + default='./output') + parser.add_argument( + '--num_workers', + dest='num_workers', + help='Num workers for data loader', + type=int, + default=0) + parser.add_argument( + '--do_eval', + dest='do_eval', + help='Eval while training', + action='store_true') + + return parser.parse_args() + + +def train(model, + train_dataset, + places=None, + eval_dataset=None, + optimizer=None, + save_dir='output', + num_epochs=100, + batch_size=2, + pretrained_model=None, + save_interval_epochs=1, + num_classes=None, + num_workers=8): + ignore_index = 
model.ignore_index + nranks = ParallelEnv().nranks + + load_pretrained_model(model, pretrained_model) + + if not os.path.isdir(save_dir): + if os.path.exists(save_dir): + os.remove(save_dir) + os.makedirs(save_dir) + + if nranks > 1: + strategy = fluid.dygraph.prepare_context() + model_parallel = fluid.dygraph.DataParallel(model, strategy) + + batch_sampler = DistributedBatchSampler( + train_dataset, batch_size=batch_size, shuffle=True, drop_last=True) + loader = DataLoader( + train_dataset, + batch_sampler=batch_sampler, + places=places, + num_workers=num_workers, + return_list=True, + ) + + num_steps_each_epoch = len(train_dataset) // batch_size + + for epoch in range(num_epochs): + for step, data in enumerate(loader): + images = data[0] + labels = data[1].astype('int64') + if nranks > 1: + loss = model_parallel(images, labels, mode='train') + loss = model_parallel.scale_loss(loss) + loss.backward() + model_parallel.apply_collective_grads() + else: + loss = model(images, labels, mode='train') + loss.backward() + optimizer.minimize(loss) + model.clear_gradients() + logging.info("[TRAIN] Epoch={}/{}, Step={}/{}, loss={}".format( + epoch + 1, num_epochs, step + 1, len(batch_sampler), + loss.numpy())) + + if ((epoch + 1) % save_interval_epochs == 0 + or num_steps_each_epoch == num_epochs - 1 + ) and ParallelEnv().local_rank == 0: + current_save_dir = os.path.join(save_dir, + "epoch_{}".format(epoch + 1)) + if not os.path.isdir(current_save_dir): + os.makedirs(current_save_dir) + fluid.save_dygraph(model.state_dict(), + os.path.join(current_save_dir, 'model')) + + if eval_dataset is not None: + evaluate( + model, + eval_dataset, + places=places, + model_dir=current_save_dir, + num_classes=num_classes, + batch_size=batch_size, + ignore_index=ignore_index, + epoch_id=epoch + 1) + model.train() + + +def main(args): + env_info = get_environ_info() + places = fluid.CUDAPlace(ParallelEnv().dev_id) \ + if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \ + else fluid.CPUPlace() + + if args.dataset.lower() == 'opticdiscseg': + dataset = OpticDiscSeg + elif args.dataset.lower() == 'cityscapes': + dataset = Cityscapes + else: + raise Exception( + "The --dataset set wrong. 
It should be one of ('OpticDiscSeg', 'Cityscapes')" + ) + + with fluid.dygraph.guard(places): + # Creat dataset reader + train_transforms = T.Compose([ + T.Resize(args.input_size), + T.RandomHorizontalFlip(), + T.Normalize() + ]) + train_dataset = dataset(transforms=train_transforms, mode='train') + + eval_dataset = None + if args.do_eval: + eval_transforms = T.Compose( + [T.Resize(args.input_size), + T.Normalize()]) + eval_dataset = dataset(transforms=eval_transforms, mode='eval') + + if args.model_name == 'UNet': + model = models.UNet( + num_classes=train_dataset.num_classes, ignore_index=255) + + # Creat optimizer + num_steps_each_epoch = len(train_dataset) // args.batch_size + decay_step = args.num_epochs * num_steps_each_epoch + lr_decay = fluid.layers.polynomial_decay( + args.learning_rate, decay_step, end_learning_rate=0, power=0.9) + optimizer = fluid.optimizer.Momentum( + lr_decay, + momentum=0.9, + parameter_list=model.parameters(), + regularization=fluid.regularizer.L2Decay(regularization_coeff=4e-5)) + + train( + model, + train_dataset, + places=places, + eval_dataset=eval_dataset, + optimizer=optimizer, + save_dir=args.save_dir, + num_epochs=args.num_epochs, + batch_size=args.batch_size, + pretrained_model=args.pretrained_model, + save_interval_epochs=args.save_interval_epochs, + num_classes=train_dataset.num_classes, + num_workers=args.num_workers) + + +if __name__ == '__main__': + args = parse_args() + main(args) diff --git a/dygraph/transforms/__init__.py b/dygraph/transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8f1d5ae80aeb1eb77ac672b1cbcfedcbfbd643c4 --- /dev/null +++ b/dygraph/transforms/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .transforms import * +from . import functional diff --git a/dygraph/transforms/functional.py b/dygraph/transforms/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..6d5a9b10db15edb05692c8aa4249912652e0a745 --- /dev/null +++ b/dygraph/transforms/functional.py @@ -0,0 +1,99 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
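+
+# Quick sanity check for the helpers defined below (illustrative; mean/std
+# are in the same [0, 1] scale the functions expect after the /255 step):
+#
+#     import numpy as np
+#     im = np.random.randint(0, 256, (100, 200, 3)).astype('uint8')
+#     out = normalize(im, np.array([0.5] * 3), np.array([0.5] * 3))
+#     assert out.shape == (100, 200, 3) and out.dtype == np.float32
+#     chw = permute(out)  # HWC -> CHW
+#     assert chw.shape == (3, 100, 200)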
+ +import cv2 +import numpy as np +from PIL import Image, ImageEnhance + + +def normalize(im, mean, std): + im = im.astype(np.float32, copy=False) / 255.0 + im -= mean + im /= std + return im + + +def permute(im): + im = np.transpose(im, (2, 0, 1)) + return im + + +def resize(im, target_size=608, interp=cv2.INTER_LINEAR): + if isinstance(target_size, list) or isinstance(target_size, tuple): + w = target_size[0] + h = target_size[1] + else: + w = target_size + h = target_size + im = cv2.resize(im, (w, h), interpolation=interp) + return im + + +def resize_long(im, long_size=224, interpolation=cv2.INTER_LINEAR): + value = max(im.shape[0], im.shape[1]) + scale = float(long_size) / float(value) + resized_width = int(round(im.shape[1] * scale)) + resized_height = int(round(im.shape[0] * scale)) + + im = cv2.resize( + im, (resized_width, resized_height), interpolation=interpolation) + return im + + +def horizontal_flip(im): + if len(im.shape) == 3: + im = im[:, ::-1, :] + elif len(im.shape) == 2: + im = im[:, ::-1] + return im + + +def vertical_flip(im): + if len(im.shape) == 3: + im = im[::-1, :, :] + elif len(im.shape) == 2: + im = im[::-1, :] + return im + + +def brightness(im, brightness_lower, brightness_upper): + brightness_delta = np.random.uniform(brightness_lower, brightness_upper) + im = ImageEnhance.Brightness(im).enhance(brightness_delta) + return im + + +def contrast(im, contrast_lower, contrast_upper): + contrast_delta = np.random.uniform(contrast_lower, contrast_upper) + im = ImageEnhance.Contrast(im).enhance(contrast_delta) + return im + + +def saturation(im, saturation_lower, saturation_upper): + saturation_delta = np.random.uniform(saturation_lower, saturation_upper) + im = ImageEnhance.Color(im).enhance(saturation_delta) + return im + + +def hue(im, hue_lower, hue_upper): + hue_delta = np.random.uniform(hue_lower, hue_upper) + im = np.array(im.convert('HSV')) + im[:, :, 0] = im[:, :, 0] + hue_delta + im = Image.fromarray(im, mode='HSV').convert('RGB') + return im + + +def rotate(im, rotate_lower, rotate_upper): + rotate_delta = np.random.uniform(rotate_lower, rotate_upper) + im = im.rotate(int(rotate_delta)) + return im diff --git a/dygraph/transforms/transforms.py b/dygraph/transforms/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..38c3be18a2ae885bfa6238304a614935401a6330 --- /dev/null +++ b/dygraph/transforms/transforms.py @@ -0,0 +1,857 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
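+
+# A typical training pipeline built from the ops in this module (a sketch;
+# it matches the composition train.py uses, with illustrative file paths):
+#
+#     import transforms as T
+#
+#     train_transforms = T.Compose([
+#         T.Resize([512, 512]),
+#         T.RandomHorizontalFlip(prob=0.5),
+#         T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+#     ])
+#     im, im_info, label = train_transforms('img.png', label='label.png')
+#
+# Each op returns (im, im_info[, label]); Compose threads these through and
+# finally permutes the image to CHW.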
+ +from .functional import * +import random +import numpy as np +from PIL import Image +import cv2 +from collections import OrderedDict + + +class Compose: + """根据数据预处理/增强算子对输入数据进行操作。 + 所有操作的输入图像流形状均是[H, W, C],其中H为图像高,W为图像宽,C为图像通道数。 + + Args: + transforms (list): 数据预处理/增强算子。 + to_rgb (bool): 是否转化为rgb通道格式 + + Raises: + TypeError: transforms不是list对象 + ValueError: transforms元素个数小于1。 + + """ + def __init__(self, transforms, to_rgb=True): + if not isinstance(transforms, list): + raise TypeError('The transforms must be a list!') + if len(transforms) < 1: + raise ValueError('The length of transforms ' + \ + 'must be equal or larger than 1!') + self.transforms = transforms + self.to_rgb = to_rgb + + def __call__(self, im, im_info=None, label=None): + """ + Args: + im (str/np.ndarray): 图像路径/图像np.ndarray数据。 + im_info (dict): 存储与图像相关的信息,dict中的字段如下: + - shape_before_resize (tuple): 图像resize之前的大小(h, w)。 + - shape_before_padding (tuple): 图像padding之前的大小(h, w)。 + label (str/np.ndarray): 标注图像路径/标注图像np.ndarray数据。 + + Returns: + tuple: 根据网络所需字段所组成的tuple;字段由transforms中的最后一个数据预处理操作决定。 + """ + + if im_info is None: + im_info = dict() + if isinstance(im, str): + im = cv2.imread(im).astype('float32') + if isinstance(label, str): + label = np.asarray(Image.open(label)) + if im is None: + raise ValueError('Can\'t read The image file {}!'.format(im)) + if self.to_rgb: + im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + + for op in self.transforms: + outputs = op(im, im_info, label) + im = outputs[0] + if len(outputs) >= 2: + im_info = outputs[1] + if len(outputs) == 3: + label = outputs[2] + im = permute(im) + if len(outputs) == 3: + label = label[np.newaxis, :, :] + return (im, im_info, label) + + +class RandomHorizontalFlip: + """以一定的概率对图像进行水平翻转。当存在标注图像时,则同步进行翻转。 + + Args: + prob (float): 随机水平翻转的概率。默认值为0.5。 + + """ + def __init__(self, prob=0.5): + self.prob = prob + + def __call__(self, im, im_info=None, label=None): + """ + Args: + im (np.ndarray): 图像np.ndarray数据。 + im_info (dict): 存储与图像相关的信息。 + label (np.ndarray): 标注图像np.ndarray数据。 + + Returns: + tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; + 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 + 存储与图像相关信息的字典和标注图像np.ndarray数据。 + """ + if random.random() < self.prob: + im = horizontal_flip(im) + if label is not None: + label = horizontal_flip(label) + if label is None: + return (im, im_info) + else: + return (im, im_info, label) + + +class RandomVerticalFlip: + """以一定的概率对图像进行垂直翻转。当存在标注图像时,则同步进行翻转。 + + Args: + prob (float): 随机垂直翻转的概率。默认值为0.1。 + """ + def __init__(self, prob=0.1): + self.prob = prob + + def __call__(self, im, im_info=None, label=None): + """ + Args: + im (np.ndarray): 图像np.ndarray数据。 + im_info (dict): 存储与图像相关的信息。 + label (np.ndarray): 标注图像np.ndarray数据。 + + Returns: + tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; + 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 + 存储与图像相关信息的字典和标注图像np.ndarray数据。 + """ + if random.random() < self.prob: + im = vertical_flip(im) + if label is not None: + label = vertical_flip(label) + if label is None: + return (im, im_info) + else: + return (im, im_info, label) + + +class Resize: + """调整图像大小(resize)。 + + - 当目标大小(target_size)类型为int时,根据插值方式, + 将图像resize为[target_size, target_size]。 + - 当目标大小(target_size)类型为list或tuple时,根据插值方式, + 将图像resize为target_size。 + 注意:当插值方式为“RANDOM”时,则随机选取一种插值方式进行resize。 + + Args: + target_size (int/list/tuple): 短边目标长度。默认为608。 + interp (str): resize的插值方式,与opencv的插值方式对应,取值范围为 + ['NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 
'RANDOM']。默认为"LINEAR"。 + + Raises: + TypeError: 形参数据类型不满足需求。 + ValueError: 插值方式不在['NEAREST', 'LINEAR', 'CUBIC', + 'AREA', 'LANCZOS4', 'RANDOM']中。 + """ + + # The interpolation mode + interp_dict = { + 'NEAREST': cv2.INTER_NEAREST, + 'LINEAR': cv2.INTER_LINEAR, + 'CUBIC': cv2.INTER_CUBIC, + 'AREA': cv2.INTER_AREA, + 'LANCZOS4': cv2.INTER_LANCZOS4 + } + + def __init__(self, target_size=512, interp='LINEAR'): + self.interp = interp + if not (interp == "RANDOM" or interp in self.interp_dict): + raise ValueError("interp should be one of {}".format( + self.interp_dict.keys())) + if isinstance(target_size, list) or isinstance(target_size, tuple): + if len(target_size) != 2: + raise TypeError( + 'when target is list or tuple, it should include 2 elements, but it is {}' + .format(target_size)) + elif not isinstance(target_size, int): + raise TypeError( + "Type of target_size is invalid. Must be Integer or List or tuple, now is {}" + .format(type(target_size))) + + self.target_size = target_size + + def __call__(self, im, im_info=None, label=None): + """ + Args: + im (np.ndarray): 图像np.ndarray数据。 + im_info (dict, 可选): 存储与图像相关的信息。 + label (np.ndarray): 标注图像np.ndarray数据。 + + Returns: + tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; + 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 + 存储与图像相关信息的字典和标注图像np.ndarray数据。 + 其中,im_info跟新字段为: + -shape_before_resize (tuple): 保存resize之前图像的形状(h, w)。 + + Raises: + TypeError: 形参数据类型不满足需求。 + ValueError: 数据长度不匹配。 + """ + if im_info is None: + im_info = OrderedDict() + im_info['shape_before_resize'] = im.shape[:2] + if not isinstance(im, np.ndarray): + raise TypeError("Resize: image type is not numpy.") + if len(im.shape) != 3: + raise ValueError('Resize: image is not 3-dimensional.') + if self.interp == "RANDOM": + interp = random.choice(list(self.interp_dict.keys())) + else: + interp = self.interp + im = resize(im, self.target_size, self.interp_dict[interp]) + if label is not None: + label = resize(label, self.target_size, cv2.INTER_NEAREST) + + if label is None: + return (im, im_info) + else: + return (im, im_info, label) + + +class ResizeByLong: + """对图像长边resize到固定值,短边按比例进行缩放。当存在标注图像时,则同步进行处理。 + + Args: + long_size (int): resize后图像的长边大小。 + """ + def __init__(self, long_size): + self.long_size = long_size + + def __call__(self, im, im_info=None, label=None): + """ + Args: + im (np.ndarray): 图像np.ndarray数据。 + im_info (dict): 存储与图像相关的信息。 + label (np.ndarray): 标注图像np.ndarray数据。 + + Returns: + tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; + 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 + 存储与图像相关信息的字典和标注图像np.ndarray数据。 + 其中,im_info新增字段为: + -shape_before_resize (tuple): 保存resize之前图像的形状(h, w)。 + """ + if im_info is None: + im_info = OrderedDict() + + im_info['shape_before_resize'] = im.shape[:2] + im = resize_long(im, self.long_size) + if label is not None: + label = resize_long(label, self.long_size, cv2.INTER_NEAREST) + + if label is None: + return (im, im_info) + else: + return (im, im_info, label) + + +class ResizeRangeScaling: + """对图像长边随机resize到指定范围内,短边按比例进行缩放。当存在标注图像时,则同步进行处理。 + + Args: + min_value (int): 图像长边resize后的最小值。默认值400。 + max_value (int): 图像长边resize后的最大值。默认值600。 + + Raises: + ValueError: min_value大于max_value + """ + def __init__(self, min_value=400, max_value=600): + if min_value > max_value: + raise ValueError('min_value must be less than max_value, ' + 'but they are {} and {}.'.format( + min_value, max_value)) + self.min_value = min_value + self.max_value = max_value + + def 
__call__(self, im, im_info=None, label=None): + """ + Args: + im (np.ndarray): 图像np.ndarray数据。 + im_info (dict): 存储与图像相关的信息。 + label (np.ndarray): 标注图像np.ndarray数据。 + + Returns: + tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; + 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 + 存储与图像相关信息的字典和标注图像np.ndarray数据。 + """ + if self.min_value == self.max_value: + random_size = self.max_value + else: + random_size = int( + np.random.uniform(self.min_value, self.max_value) + 0.5) + im = resize_long(im, random_size, cv2.INTER_LINEAR) + if label is not None: + label = resize_long(label, random_size, cv2.INTER_NEAREST) + + if label is None: + return (im, im_info) + else: + return (im, im_info, label) + + +class ResizeStepScaling: + """对图像按照某一个比例resize,这个比例以scale_step_size为步长 + 在[min_scale_factor, max_scale_factor]随机变动。当存在标注图像时,则同步进行处理。 + + Args: + min_scale_factor(float), resize最小尺度。默认值0.75。 + max_scale_factor (float), resize最大尺度。默认值1.25。 + scale_step_size (float), resize尺度范围间隔。默认值0.25。 + + Raises: + ValueError: min_scale_factor大于max_scale_factor + """ + def __init__(self, + min_scale_factor=0.75, + max_scale_factor=1.25, + scale_step_size=0.25): + if min_scale_factor > max_scale_factor: + raise ValueError( + 'min_scale_factor must be less than max_scale_factor, ' + 'but they are {} and {}.'.format(min_scale_factor, + max_scale_factor)) + self.min_scale_factor = min_scale_factor + self.max_scale_factor = max_scale_factor + self.scale_step_size = scale_step_size + + def __call__(self, im, im_info=None, label=None): + """ + Args: + im (np.ndarray): 图像np.ndarray数据。 + im_info (dict): 存储与图像相关的信息。 + label (np.ndarray): 标注图像np.ndarray数据。 + + Returns: + tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; + 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 + 存储与图像相关信息的字典和标注图像np.ndarray数据。 + """ + if self.min_scale_factor == self.max_scale_factor: + scale_factor = self.min_scale_factor + + elif self.scale_step_size == 0: + scale_factor = np.random.uniform(self.min_scale_factor, + self.max_scale_factor) + + else: + num_steps = int((self.max_scale_factor - self.min_scale_factor) / + self.scale_step_size + 1) + scale_factors = np.linspace(self.min_scale_factor, + self.max_scale_factor, + num_steps).tolist() + np.random.shuffle(scale_factors) + scale_factor = scale_factors[0] + w = int(round(scale_factor * im.shape[1])) + h = int(round(scale_factor * im.shape[0])) + + im = resize(im, (w, h), cv2.INTER_LINEAR) + if label is not None: + label = resize(label, (w, h), cv2.INTER_NEAREST) + + if label is None: + return (im, im_info) + else: + return (im, im_info, label) + + +class Normalize: + """对图像进行标准化。 + 1.尺度缩放到 [0,1]。 + 2.对图像进行减均值除以标准差操作。 + + Args: + mean (list): 图像数据集的均值。默认值[0.5, 0.5, 0.5]。 + std (list): 图像数据集的标准差。默认值[0.5, 0.5, 0.5]。 + + Raises: + ValueError: mean或std不是list对象。std包含0。 + """ + def __init__(self, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]): + self.mean = mean + self.std = std + if not (isinstance(self.mean, list) and isinstance(self.std, list)): + raise ValueError("{}: input type is invalid.".format(self)) + from functools import reduce + if reduce(lambda x, y: x * y, self.std) == 0: + raise ValueError('{}: std is invalid!'.format(self)) + + def __call__(self, im, im_info=None, label=None): + """ + Args: + im (np.ndarray): 图像np.ndarray数据。 + im_info (dict): 存储与图像相关的信息。 + label (np.ndarray): 标注图像np.ndarray数据。 + + Returns: + tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; + 当label不为空时,返回的tuple为(im, im_info, 
label),分别对应图像np.ndarray数据、 + 存储与图像相关信息的字典和标注图像np.ndarray数据。 + """ + + mean = np.array(self.mean)[np.newaxis, np.newaxis, :] + std = np.array(self.std)[np.newaxis, np.newaxis, :] + im = normalize(im, mean, std) + + if label is None: + return (im, im_info) + else: + return (im, im_info, label) + + +class Padding: + """对图像或标注图像进行padding,padding方向为右和下。 + 根据提供的值对图像或标注图像进行padding操作。 + + Args: + target_size (int|list|tuple): padding后图像的大小。 + im_padding_value (list): 图像padding的值。默认为[127.5, 127.5, 127.5]。 + label_padding_value (int): 标注图像padding的值。默认值为255。 + + Raises: + TypeError: target_size不是int|list|tuple。 + ValueError: target_size为list|tuple时元素个数不等于2。 + """ + def __init__(self, + target_size, + im_padding_value=[127.5, 127.5, 127.5], + label_padding_value=255): + if isinstance(target_size, list) or isinstance(target_size, tuple): + if len(target_size) != 2: + raise ValueError( + 'when target is list or tuple, it should include 2 elements, but it is {}' + .format(target_size)) + elif not isinstance(target_size, int): + raise TypeError( + "Type of target_size is invalid. Must be Integer or List or tuple, now is {}" + .format(type(target_size))) + self.target_size = target_size + self.im_padding_value = im_padding_value + self.label_padding_value = label_padding_value + + def __call__(self, im, im_info=None, label=None): + """ + Args: + im (np.ndarray): 图像np.ndarray数据。 + im_info (dict): 存储与图像相关的信息。 + label (np.ndarray): 标注图像np.ndarray数据。 + + Returns: + tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; + 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 + 存储与图像相关信息的字典和标注图像np.ndarray数据。 + 其中,im_info新增字段为: + -shape_before_padding (tuple): 保存padding之前图像的形状(h, w)。 + + Raises: + ValueError: 输入图像im或label的形状大于目标值 + """ + if im_info is None: + im_info = OrderedDict() + im_info['shape_before_padding'] = im.shape[:2] + + im_height, im_width = im.shape[0], im.shape[1] + if isinstance(self.target_size, int): + target_height = self.target_size + target_width = self.target_size + else: + target_height = self.target_size[1] + target_width = self.target_size[0] + pad_height = target_height - im_height + pad_width = target_width - im_width + if pad_height < 0 or pad_width < 0: + raise ValueError( + 'the size of image should be less than target_size, but the size of image ({}, {}), is larger than target_size ({}, {})' + .format(im_width, im_height, target_width, target_height)) + else: + im = cv2.copyMakeBorder(im, + 0, + pad_height, + 0, + pad_width, + cv2.BORDER_CONSTANT, + value=self.im_padding_value) + if label is not None: + label = cv2.copyMakeBorder(label, + 0, + pad_height, + 0, + pad_width, + cv2.BORDER_CONSTANT, + value=self.label_padding_value) + if label is None: + return (im, im_info) + else: + return (im, im_info, label) + + +class RandomPaddingCrop: + """对图像和标注图进行随机裁剪,当所需要的裁剪尺寸大于原图时,则进行padding操作。 + + Args: + crop_size (int|list|tuple): 裁剪图像大小。默认为512。 + im_padding_value (list): 图像padding的值。默认为[127.5, 127.5, 127.5]。 + label_padding_value (int): 标注图像padding的值。默认值为255。 + + Raises: + TypeError: crop_size不是int/list/tuple。 + ValueError: target_size为list/tuple时元素个数不等于2。 + """ + def __init__(self, + crop_size=512, + im_padding_value=[127.5, 127.5, 127.5], + label_padding_value=255): + if isinstance(crop_size, list) or isinstance(crop_size, tuple): + if len(crop_size) != 2: + raise ValueError( + 'when crop_size is list or tuple, it should include 2 elements, but it is {}' + .format(crop_size)) + elif not isinstance(crop_size, int): + raise TypeError( + "Type of crop_size is 
invalid. Must be Integer or List or tuple, now is {}" + .format(type(crop_size))) + self.crop_size = crop_size + self.im_padding_value = im_padding_value + self.label_padding_value = label_padding_value + + def __call__(self, im, im_info=None, label=None): + """ + Args: + im (np.ndarray): 图像np.ndarray数据。 + im_info (dict): 存储与图像相关的信息。 + label (np.ndarray): 标注图像np.ndarray数据。 + + Returns: + tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; + 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 + 存储与图像相关信息的字典和标注图像np.ndarray数据。 + """ + if isinstance(self.crop_size, int): + crop_width = self.crop_size + crop_height = self.crop_size + else: + crop_width = self.crop_size[0] + crop_height = self.crop_size[1] + + img_height = im.shape[0] + img_width = im.shape[1] + + if img_height == crop_height and img_width == crop_width: + if label is None: + return (im, im_info) + else: + return (im, im_info, label) + else: + pad_height = max(crop_height - img_height, 0) + pad_width = max(crop_width - img_width, 0) + if (pad_height > 0 or pad_width > 0): + im = cv2.copyMakeBorder(im, + 0, + pad_height, + 0, + pad_width, + cv2.BORDER_CONSTANT, + value=self.im_padding_value) + if label is not None: + label = cv2.copyMakeBorder(label, + 0, + pad_height, + 0, + pad_width, + cv2.BORDER_CONSTANT, + value=self.label_padding_value) + img_height = im.shape[0] + img_width = im.shape[1] + + if crop_height > 0 and crop_width > 0: + h_off = np.random.randint(img_height - crop_height + 1) + w_off = np.random.randint(img_width - crop_width + 1) + + im = im[h_off:(crop_height + h_off), w_off:(w_off + + crop_width), :] + if label is not None: + label = label[h_off:(crop_height + + h_off), w_off:(w_off + crop_width)] + if label is None: + return (im, im_info) + else: + return (im, im_info, label) + + +class RandomBlur: + """以一定的概率对图像进行高斯模糊。 + + Args: + prob (float): 图像模糊概率。默认为0.1。 + """ + def __init__(self, prob=0.1): + self.prob = prob + + def __call__(self, im, im_info=None, label=None): + """ + Args: + im (np.ndarray): 图像np.ndarray数据。 + im_info (dict): 存储与图像相关的信息。 + label (np.ndarray): 标注图像np.ndarray数据。 + + Returns: + tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; + 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 + 存储与图像相关信息的字典和标注图像np.ndarray数据。 + """ + if self.prob <= 0: + n = 0 + elif self.prob >= 1: + n = 1 + else: + n = int(1.0 / self.prob) + if n > 0: + if np.random.randint(0, n) == 0: + radius = np.random.randint(3, 10) + if radius % 2 != 1: + radius = radius + 1 + if radius > 9: + radius = 9 + im = cv2.GaussianBlur(im, (radius, radius), 0, 0) + + if label is None: + return (im, im_info) + else: + return (im, im_info, label) + + +class RandomRotation: + """对图像进行随机旋转。 + 在不超过最大旋转角度的情况下,图像进行随机旋转,当存在标注图像时,同步进行, + 并对旋转后的图像和标注图像进行相应的padding。 + + Args: + max_rotation (float): 最大旋转角度。默认为15度。 + im_padding_value (list): 图像padding的值。默认为[127.5, 127.5, 127.5]。 + label_padding_value (int): 标注图像padding的值。默认为255。 + + """ + def __init__(self, + max_rotation=15, + im_padding_value=[127.5, 127.5, 127.5], + label_padding_value=255): + self.max_rotation = max_rotation + self.im_padding_value = im_padding_value + self.label_padding_value = label_padding_value + + def __call__(self, im, im_info=None, label=None): + """ + Args: + im (np.ndarray): 图像np.ndarray数据。 + im_info (dict): 存储与图像相关的信息。 + label (np.ndarray): 标注图像np.ndarray数据。 + + Returns: + tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; + 当label不为空时,返回的tuple为(im, im_info, 
label),分别对应图像np.ndarray数据、 + 存储与图像相关信息的字典和标注图像np.ndarray数据。 + """ + if self.max_rotation > 0: + (h, w) = im.shape[:2] + do_rotation = np.random.uniform(-self.max_rotation, + self.max_rotation) + pc = (w // 2, h // 2) + r = cv2.getRotationMatrix2D(pc, do_rotation, 1.0) + cos = np.abs(r[0, 0]) + sin = np.abs(r[0, 1]) + + nw = int((h * sin) + (w * cos)) + nh = int((h * cos) + (w * sin)) + + (cx, cy) = pc + r[0, 2] += (nw / 2) - cx + r[1, 2] += (nh / 2) - cy + dsize = (nw, nh) + im = cv2.warpAffine(im, + r, + dsize=dsize, + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=self.im_padding_value) + label = cv2.warpAffine(label, + r, + dsize=dsize, + flags=cv2.INTER_NEAREST, + borderMode=cv2.BORDER_CONSTANT, + borderValue=self.label_padding_value) + + if label is None: + return (im, im_info) + else: + return (im, im_info, label) + + +class RandomScaleAspect: + """裁剪并resize回原始尺寸的图像和标注图像。 + 按照一定的面积比和宽高比对图像进行裁剪,并reszie回原始图像的图像,当存在标注图时,同步进行。 + + Args: + min_scale (float):裁取图像占原始图像的面积比,取值[0,1],为0时则返回原图。默认为0.5。 + aspect_ratio (float): 裁取图像的宽高比范围,非负值,为0时返回原图。默认为0.33。 + """ + def __init__(self, min_scale=0.5, aspect_ratio=0.33): + self.min_scale = min_scale + self.aspect_ratio = aspect_ratio + + def __call__(self, im, im_info=None, label=None): + """ + Args: + im (np.ndarray): 图像np.ndarray数据。 + im_info (dict): 存储与图像相关的信息。 + label (np.ndarray): 标注图像np.ndarray数据。 + + Returns: + tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; + 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 + 存储与图像相关信息的字典和标注图像np.ndarray数据。 + """ + if self.min_scale != 0 and self.aspect_ratio != 0: + img_height = im.shape[0] + img_width = im.shape[1] + for i in range(0, 10): + area = img_height * img_width + target_area = area * np.random.uniform(self.min_scale, 1.0) + aspectRatio = np.random.uniform(self.aspect_ratio, + 1.0 / self.aspect_ratio) + + dw = int(np.sqrt(target_area * 1.0 * aspectRatio)) + dh = int(np.sqrt(target_area * 1.0 / aspectRatio)) + if (np.random.randint(10) < 5): + tmp = dw + dw = dh + dh = tmp + + if (dh < img_height and dw < img_width): + h1 = np.random.randint(0, img_height - dh) + w1 = np.random.randint(0, img_width - dw) + + im = im[h1:(h1 + dh), w1:(w1 + dw), :] + label = label[h1:(h1 + dh), w1:(w1 + dw)] + im = cv2.resize(im, (img_width, img_height), + interpolation=cv2.INTER_LINEAR) + label = cv2.resize(label, (img_width, img_height), + interpolation=cv2.INTER_NEAREST) + break + if label is None: + return (im, im_info) + else: + return (im, im_info, label) + + +class RandomDistort: + """对图像进行随机失真。 + + 1. 对变换的操作顺序进行随机化操作。 + 2. 
按照1中的顺序以一定的概率对图像进行随机像素内容变换。 + + Args: + brightness_range (float): 明亮度因子的范围。默认为0.5。 + brightness_prob (float): 随机调整明亮度的概率。默认为0.5。 + contrast_range (float): 对比度因子的范围。默认为0.5。 + contrast_prob (float): 随机调整对比度的概率。默认为0.5。 + saturation_range (float): 饱和度因子的范围。默认为0.5。 + saturation_prob (float): 随机调整饱和度的概率。默认为0.5。 + hue_range (int): 色调因子的范围。默认为18。 + hue_prob (float): 随机调整色调的概率。默认为0.5。 + """ + def __init__(self, + brightness_range=0.5, + brightness_prob=0.5, + contrast_range=0.5, + contrast_prob=0.5, + saturation_range=0.5, + saturation_prob=0.5, + hue_range=18, + hue_prob=0.5): + self.brightness_range = brightness_range + self.brightness_prob = brightness_prob + self.contrast_range = contrast_range + self.contrast_prob = contrast_prob + self.saturation_range = saturation_range + self.saturation_prob = saturation_prob + self.hue_range = hue_range + self.hue_prob = hue_prob + + def __call__(self, im, im_info=None, label=None): + """ + Args: + im (np.ndarray): 图像np.ndarray数据。 + im_info (dict): 存储与图像相关的信息。 + label (np.ndarray): 标注图像np.ndarray数据。 + + Returns: + tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; + 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 + 存储与图像相关信息的字典和标注图像np.ndarray数据。 + """ + brightness_lower = 1 - self.brightness_range + brightness_upper = 1 + self.brightness_range + contrast_lower = 1 - self.contrast_range + contrast_upper = 1 + self.contrast_range + saturation_lower = 1 - self.saturation_range + saturation_upper = 1 + self.saturation_range + hue_lower = -self.hue_range + hue_upper = self.hue_range + ops = [brightness, contrast, saturation, hue] + random.shuffle(ops) + params_dict = { + 'brightness': { + 'brightness_lower': brightness_lower, + 'brightness_upper': brightness_upper + }, + 'contrast': { + 'contrast_lower': contrast_lower, + 'contrast_upper': contrast_upper + }, + 'saturation': { + 'saturation_lower': saturation_lower, + 'saturation_upper': saturation_upper + }, + 'hue': { + 'hue_lower': hue_lower, + 'hue_upper': hue_upper + } + } + prob_dict = { + 'brightness': self.brightness_prob, + 'contrast': self.contrast_prob, + 'saturation': self.saturation_prob, + 'hue': self.hue_prob + } + im = im.astype('uint8') + im = Image.fromarray(im) + for id in range(4): + params = params_dict[ops[id].__name__] + prob = prob_dict[ops[id].__name__] + params['im'] = im + if np.random.uniform(0, 1) < prob: + im = ops[id](**params) + im = np.asarray(im).astype('float32') + if label is None: + return (im, im_info) + else: + return (im, im_info, label) diff --git a/dygraph/utils/__init__.py b/dygraph/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7579cf7f0ed9f051b154d7bc2f99fc25ac246d4a --- /dev/null +++ b/dygraph/utils/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import logging +from . 
import download +from .metrics import ConfusionMatrix +from .utils import * diff --git a/dygraph/utils/download.py b/dygraph/utils/download.py new file mode 100644 index 0000000000000000000000000000000000000000..848675fa893de81475c418393c0d9e2c262a1a48 --- /dev/null +++ b/dygraph/utils/download.py @@ -0,0 +1,135 @@ +import os +import sys +import time +import requests +import tarfile +import zipfile +import shutil +import functools + +lasttime = time.time() +FLUSH_INTERVAL = 0.1 + + +def progress(str, end=False): + global lasttime + if end: + str += "\n" + lasttime = 0 + if time.time() - lasttime >= FLUSH_INTERVAL: + sys.stdout.write("\r%s" % str) + lasttime = time.time() + sys.stdout.flush() + + +def _download_file(url, savepath, print_progress): + r = requests.get(url, stream=True) + total_length = r.headers.get('content-length') + + if total_length is None: + with open(savepath, 'wb') as f: + shutil.copyfileobj(r.raw, f) + else: + with open(savepath, 'wb') as f: + dl = 0 + total_length = int(total_length) + starttime = time.time() + if print_progress: + print("Downloading %s" % os.path.basename(savepath)) + for data in r.iter_content(chunk_size=4096): + dl += len(data) + f.write(data) + if print_progress: + done = int(50 * dl / total_length) + progress("[%-50s] %.2f%%" % + ('=' * done, float(100 * dl) / total_length)) + if print_progress: + progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True) + + +def _uncompress_file_zip(filepath, extrapath): + files = zipfile.ZipFile(filepath, 'r') + filelist = files.namelist() + rootpath = filelist[0] + total_num = len(filelist) + for index, file in enumerate(filelist): + files.extract(file, extrapath) + yield total_num, index, rootpath + files.close() + yield total_num, index, rootpath + + +def _uncompress_file_tar(filepath, extrapath, mode="r:gz"): + files = tarfile.open(filepath, mode) + filelist = files.getnames() + total_num = len(filelist) + rootpath = filelist[0] + for index, file in enumerate(filelist): + files.extract(file, extrapath) + yield total_num, index, rootpath + files.close() + yield total_num, index, rootpath + + +def _uncompress_file(filepath, extrapath, delete_file, print_progress): + if print_progress: + print("Uncompress %s" % os.path.basename(filepath)) + + if filepath.endswith("zip"): + handler = _uncompress_file_zip + elif filepath.endswith("tgz"): + handler = _uncompress_file_tar + else: + handler = functools.partial(_uncompress_file_tar, mode="r") + + for total_num, index, rootpath in handler(filepath, extrapath): + if print_progress: + done = int(50 * float(index) / total_num) + progress("[%-50s] %.2f%%" % + ('=' * done, float(100 * index) / total_num)) + if print_progress: + progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True) + + if delete_file: + os.remove(filepath) + + return rootpath + + +def download_file_and_uncompress(url, + savepath=None, + extrapath=None, + extraname=None, + print_progress=True, + cover=False, + delete_file=True): + if savepath is None: + savepath = "." + + if extrapath is None: + extrapath = "." 
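+
+    # Illustrative walk-through of the path handling below, assuming
+    # url='https://.../cityscapes.tar' and savepath=extrapath=DATA_HOME:
+    #     savepath  -> DATA_HOME/cityscapes.tar  (download target)
+    #     savename  -> DATA_HOME/cityscapes      (archive stem)
+    #     extraname -> savename, unless an explicit extraname is given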
+ + savename = url.split("/")[-1] + savepath = os.path.join(savepath, savename) + savename = ".".join(savename.split(".")[:-1]) + savename = os.path.join(extrapath, savename) + extraname = savename if extraname is None else os.path.join( + extrapath, extraname) + + if cover: + if os.path.exists(savepath): + shutil.rmtree(savepath) + if os.path.exists(savename): + shutil.rmtree(savename) + if os.path.exists(extraname): + shutil.rmtree(extraname) + + if not os.path.exists(extraname): + if not os.path.exists(savename): + if not os.path.exists(savepath): + _download_file(url, savepath, print_progress) + savename = _uncompress_file(savepath, extrapath, delete_file, + print_progress) + savename = os.path.join(extrapath, savename) + shutil.move(savename, extraname) + return savename diff --git a/dygraph/utils/logging.py b/dygraph/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..015948f65090e40895f6d4a72a75a11f2b155447 --- /dev/null +++ b/dygraph/utils/logging.py @@ -0,0 +1,50 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import os +import sys + +from paddle.fluid.dygraph.parallel import ParallelEnv + +levels = {0: 'ERROR', 1: 'WARNING', 2: 'INFO', 3: 'DEBUG'} +log_level = 2 + + +def log(level=2, message=""): + if ParallelEnv().local_rank == 0: + current_time = time.time() + time_array = time.localtime(current_time) + current_time = time.strftime("%Y-%m-%d %H:%M:%S", time_array) + if log_level >= level: + print( + "{} [{}]\t{}".format(current_time, levels[level], + message).encode("utf-8").decode("latin1")) + sys.stdout.flush() + + +def debug(message=""): + log(level=3, message=message) + + +def info(message=""): + log(level=2, message=message) + + +def warning(message=""): + log(level=1, message=message) + + +def error(message=""): + log(level=0, message=message) diff --git a/dygraph/utils/metrics.py b/dygraph/utils/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..b107cbd57a936fb909086567fc8b703fb86963b7 --- /dev/null +++ b/dygraph/utils/metrics.py @@ -0,0 +1,144 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
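+
+# Usage sketch for ConfusionMatrix below (the shapes are an assumption based
+# on the model's test-mode output: pred is NHWC with a trailing singleton
+# channel, while label/ignore are NCHW with a leading singleton channel):
+#
+#     import numpy as np
+#     cm = ConfusionMatrix(num_classes=2, streaming=True)
+#     pred = np.zeros((1, 4, 4, 1), dtype='int64')
+#     label = np.zeros((1, 1, 4, 4), dtype='int64')
+#     ignore = np.ones_like(label)  # 1 = counted, 0 = ignored
+#     cm.calculate(pred=pred, label=label, ignore=ignore)
+#     iou_per_class, miou = cm.mean_iou()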
diff --git a/dygraph/utils/metrics.py b/dygraph/utils/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..b107cbd57a936fb909086567fc8b703fb86963b7
--- /dev/null
+++ b/dygraph/utils/metrics.py
@@ -0,0 +1,144 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import numpy as np
+from scipy.sparse import csr_matrix
+
+
+class ConfusionMatrix(object):
+    """
+    Confusion Matrix for segmentation evaluation.
+    """
+
+    def __init__(self, num_classes=2, streaming=False):
+        self.confusion_matrix = np.zeros([num_classes, num_classes],
+                                         dtype='int64')
+        self.num_classes = num_classes
+        self.streaming = streaming
+
+    def calculate(self, pred, label, ignore=None):
+        # If not in streaming mode, clear the matrix every time `calculate` is called
+        if not self.streaming:
+            self.zero_matrix()
+
+        label = np.transpose(label, (0, 2, 3, 1))
+        ignore = np.transpose(ignore, (0, 2, 3, 1))
+        mask = np.array(ignore) == 1
+
+        label = np.asarray(label)[mask]
+        pred = np.asarray(pred)[mask]
+        one = np.ones_like(pred)
+        # Accumulate ([row=label, col=pred], 1) into a sparse matrix
+        spm = csr_matrix((one, (label, pred)),
+                         shape=(self.num_classes, self.num_classes))
+        spm = spm.todense()
+        self.confusion_matrix += spm
+
+    def zero_matrix(self):
+        """ Clear the confusion matrix """
+        self.confusion_matrix = np.zeros([self.num_classes, self.num_classes],
+                                         dtype='int64')
+
+    def mean_iou(self):
+        iou_list = []
+        avg_iou = 0
+        # TODO: use the numpy sum-axis API to simplify
+        vji = np.zeros(self.num_classes, dtype=int)
+        vij = np.zeros(self.num_classes, dtype=int)
+        for j in range(self.num_classes):
+            v_j = 0
+            for i in range(self.num_classes):
+                v_j += self.confusion_matrix[j][i]
+            vji[j] = v_j
+
+        for i in range(self.num_classes):
+            v_i = 0
+            for j in range(self.num_classes):
+                v_i += self.confusion_matrix[j][i]
+            vij[i] = v_i
+
+        for c in range(self.num_classes):
+            total = vji[c] + vij[c] - self.confusion_matrix[c][c]
+            if total == 0:
+                iou = 0
+            else:
+                iou = float(self.confusion_matrix[c][c]) / total
+            avg_iou += iou
+            iou_list.append(iou)
+        avg_iou = float(avg_iou) / float(self.num_classes)
+        return np.array(iou_list), avg_iou
+
+    def accuracy(self):
+        total = self.confusion_matrix.sum()
+        total_right = 0
+        for c in range(self.num_classes):
+            total_right += self.confusion_matrix[c][c]
+        if total == 0:
+            avg_acc = 0
+        else:
+            avg_acc = float(total_right) / total
+
+        vij = np.zeros(self.num_classes, dtype=int)
+        for i in range(self.num_classes):
+            v_i = 0
+            for j in range(self.num_classes):
+                v_i += self.confusion_matrix[j][i]
+            vij[i] = v_i
+
+        acc_list = []
+        for c in range(self.num_classes):
+            if vij[c] == 0:
+                acc = 0
+            else:
+                acc = self.confusion_matrix[c][c] / float(vij[c])
+            acc_list.append(acc)
+        return np.array(acc_list), avg_acc
+
+    def kappa(self):
+        vji = np.zeros(self.num_classes)
+        vij = np.zeros(self.num_classes)
+        for j in range(self.num_classes):
+            v_j = 0
+            for i in range(self.num_classes):
+                v_j += self.confusion_matrix[j][i]
+            vji[j] = v_j
+
+        for i in range(self.num_classes):
+            v_i = 0
+            for j in range(self.num_classes):
+                v_i += self.confusion_matrix[j][i]
+            vij[i] = v_i
+
+        total = self.confusion_matrix.sum()
+
+        # Scale counts down to avoid overflow in the products below
+        # TODO: is it reasonable to hard-code 10000.0?
+        total = float(total) / 10000.0
+        vji = vji / 10000.0
+        vij = vij / 10000.0
+
+        tp = 0
+        tc = 0
+        for c in range(self.num_classes):
+            tp += vji[c] * vij[c]
+            tc += self.confusion_matrix[c][c]
+
+        tc = tc / 10000.0
+        pe = tp / (total * total)
+        po = tc / total
+
+        kappa = (po - pe) / (1 - pe)
+        return kappa
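`ConfusionMatrix` is consumed batch-by-batch by the evaluation loop in `val.py` below. Here is a toy sketch of the streaming call pattern; the array shapes are an assumption inferred from `calculate`, which transposes `label` and `ignore` out of NCHW but uses `pred` as-is:

```
# Sketch only -- two fake batches accumulated into one streaming matrix.
import numpy as np
from utils.metrics import ConfusionMatrix

conf_mat = ConfusionMatrix(num_classes=2, streaming=True)
for _ in range(2):
    pred = np.random.randint(0, 2, (1, 4, 4, 1))    # N, H, W, 1
    label = np.random.randint(0, 2, (1, 1, 4, 4))   # N, 1, H, W
    ignore = (label != 255).astype('int32')         # 1 = pixel is counted
    conf_mat.calculate(pred=pred, label=label, ignore=ignore)

category_iou, miou = conf_mat.mean_iou()
category_acc, macc = conf_mat.accuracy()
print(miou, macc, conf_mat.kappa())
```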
diff --git a/dygraph/utils/utils.py b/dygraph/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a450b352e0dcf98c1eeaa093878c9b3ba649dfd
--- /dev/null
+++ b/dygraph/utils/utils.py
@@ -0,0 +1,132 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import numpy as np
+import math
+import cv2
+import paddle.fluid as fluid
+
+from . import logging
+
+
+def seconds_to_hms(seconds):
+    h = math.floor(seconds / 3600)
+    m = math.floor((seconds - h * 3600) / 60)
+    s = int(seconds - h * 3600 - m * 60)
+    hms_str = "{}:{}:{}".format(h, m, s)
+    return hms_str
+
+
+def get_environ_info():
+    info = dict()
+    info['place'] = 'cpu'
+    info['num'] = int(os.environ.get('CPU_NUM', 1))
+    if os.environ.get('CUDA_VISIBLE_DEVICES', None) != "":
+        if hasattr(fluid.core, 'get_cuda_device_count'):
+            gpu_num = 0
+            try:
+                gpu_num = fluid.core.get_cuda_device_count()
+            except Exception:
+                os.environ['CUDA_VISIBLE_DEVICES'] = ''
+            if gpu_num > 0:
+                info['place'] = 'cuda'
+                info['num'] = fluid.core.get_cuda_device_count()
+    return info
+
+
+def load_pretrained_model(model, pretrained_model):
+    if pretrained_model is not None:
+        logging.info('Load pretrained model!')
+        if os.path.exists(pretrained_model):
+            ckpt_path = os.path.join(pretrained_model, 'model')
+            para_state_dict, _ = fluid.load_dygraph(ckpt_path)
+            model_state_dict = model.state_dict()
+            keys = model_state_dict.keys()
+            num_params_loaded = 0
+            for k in keys:
+                if k not in para_state_dict:
+                    logging.warning("{} is not in pretrained model".format(k))
+                elif list(para_state_dict[k].shape) != list(
+                        model_state_dict[k].shape):
+                    logging.warning(
+                        "[SKIP] Shape of pretrained params {} doesn't match. (Pretrained: {}, Actual: {})"
+                        .format(k, para_state_dict[k].shape,
+                                model_state_dict[k].shape))
+                else:
+                    model_state_dict[k] = para_state_dict[k]
+                    num_params_loaded += 1
+            model.set_dict(model_state_dict)
+            logging.info("{}/{} variables are loaded.".format(
+                num_params_loaded, len(model_state_dict)))
+        else:
+            raise ValueError(
+                'The pretrained model directory is not found: {}'.format(
+                    pretrained_model))
+
+
+def visualize(image, result, save_dir=None, weight=0.6):
+    """
+    Convert a segmentation result to a color image and, optionally, save the blended image.
+
+    Args:
+        image: the path of the origin image
+        result: the predicted segmentation result of the image
+        save_dir: the directory for saving the visualization
+        weight: the weight of the origin image in the blend; the result gets (1 - weight)
+    """
+    color_map = get_color_map_list(256)
+    color_map = np.array(color_map).astype("uint8")
+    # Use an OpenCV LUT to map each class id to a color, one channel at a time
+    c1 = cv2.LUT(result, color_map[:, 0])
+    c2 = cv2.LUT(result, color_map[:, 1])
+    c3 = cv2.LUT(result, color_map[:, 2])
+    pseudo_img = np.dstack((c1, c2, c3))
+
+    im = cv2.imread(image)
+    vis_result = cv2.addWeighted(im, weight, pseudo_img, 1 - weight, 0)
+
+    if save_dir is not None:
+        if not os.path.exists(save_dir):
+            os.makedirs(save_dir)
+        image_name = os.path.split(image)[-1]
+        out_path = os.path.join(save_dir, image_name)
+        cv2.imwrite(out_path, vis_result)
+    else:
+        return vis_result
+
+
+def get_color_map_list(num_classes):
+    """ Returns the color map for visualizing the segmentation mask,
+    which supports an arbitrary number of classes.
+
+    Args:
+        num_classes: the number of classes
+
+    Returns:
+        The color map, a list of [R, G, B] triples
+    """
+    num_classes += 1
+    color_map = num_classes * [0, 0, 0]
+    for i in range(0, num_classes):
+        j = 0
+        lab = i
+        while lab:
+            color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
+            color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
+            color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
+            j += 1
+            lab >>= 3
+    color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
+    color_map = color_map[1:]
+    return color_map
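Of these helpers, `visualize` is the one inference scripts touch directly: it color-maps a predicted label image through the VOC-style palette from `get_color_map_list` and alpha-blends it over the source photo. A sketch with a dummy mask follows; the image path is illustrative and must point to a real file whose height and width match the mask:

```
# Sketch only -- overlay a fake 2-class mask on an image.
import numpy as np
from utils.utils import visualize

mask = np.zeros((480, 640), dtype='uint8')  # class ids, same H x W as image
mask[100:200, 150:300] = 1                  # pretend class 1 was predicted here

# 60% original image + 40% color-mapped mask, written to vis/example.png
visualize('data/example.png', mask, save_dir='vis', weight=0.6)
```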
diff --git a/dygraph/val.py b/dygraph/val.py
new file mode 100644
index 0000000000000000000000000000000000000000..358bcd83b3e32cc0f86b334ec5b09748e593ee1e
--- /dev/null
+++ b/dygraph/val.py
@@ -0,0 +1,161 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+import math
+
+from paddle.fluid.dygraph.base import to_variable
+import numpy as np
+import paddle.fluid as fluid
+from paddle.fluid.dygraph.parallel import ParallelEnv
+from paddle.fluid.io import DataLoader
+from paddle.fluid.dataloader import BatchSampler
+
+from datasets import OpticDiscSeg, Cityscapes
+import transforms as T
+import models
+import utils.logging as logging
+from utils import get_environ_info
+from utils import ConfusionMatrix
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Model evaluation')
+
+    # params of model
+    parser.add_argument(
+        '--model_name',
+        dest='model_name',
+        help="Model type for evaluation, which is one of ('UNet')",
+        type=str,
+        default='UNet')
+
+    # params of dataset
+    parser.add_argument(
+        '--dataset',
+        dest='dataset',
+        help="The dataset you want to evaluate, which is one of ('OpticDiscSeg', 'Cityscapes')",
+        type=str,
+        default='OpticDiscSeg')
+
+    # params of evaluation
+    parser.add_argument(
+        '--input_size',
+        dest='input_size',
+        help='The image size for net inputs.',
+        nargs=2,
+        default=[512, 512],
+        type=int)
+    parser.add_argument(
+        '--batch_size',
+        dest='batch_size',
+        help='Mini batch size',
+        type=int,
+        default=2)
+    parser.add_argument(
+        '--model_dir',
+        dest='model_dir',
+        help='The path of the model for evaluation',
+        type=str,
+        default=None)
+
+    return parser.parse_args()
+
+
+def evaluate(model,
+             eval_dataset=None,
+             places=None,
+             model_dir=None,
+             num_classes=None,
+             batch_size=2,
+             ignore_index=255,
+             epoch_id=None):
+    ckpt_path = os.path.join(model_dir, 'model')
+    para_state_dict, opti_state_dict = fluid.load_dygraph(ckpt_path)
+    model.set_dict(para_state_dict)
+    model.eval()
+
+    batch_sampler = BatchSampler(
+        eval_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
+    loader = DataLoader(
+        eval_dataset,
+        batch_sampler=batch_sampler,
+        places=places,
+        return_list=True,
+    )
+    total_steps = math.ceil(len(eval_dataset) * 1.0 / batch_size)
+    conf_mat = ConfusionMatrix(num_classes, streaming=True)
+
+    logging.info(
+        "Start evaluating (total_samples={}, total_steps={})...".format(
+            len(eval_dataset), total_steps))
+    for step, data in enumerate(loader):
+        images = data[0]
+        labels = data[1].astype('int64')
+        pred, _ = model(images, mode='eval')
+
+        pred = pred.numpy()
+        labels = labels.numpy()
+        mask = labels != ignore_index
+        conf_mat.calculate(pred=pred, label=labels, ignore=mask)
+        _, iou = conf_mat.mean_iou()
+
+        logging.info("[EVAL] Epoch={}, Step={}/{}, iou={}".format(
+            epoch_id, step + 1, total_steps, iou))
+
+    category_iou, miou = conf_mat.mean_iou()
+    category_acc, macc = conf_mat.accuracy()
+    logging.info("[EVAL] #image={} acc={:.4f} IoU={:.4f}".format(
+        len(eval_dataset), macc, miou))
+    logging.info("[EVAL] Category IoU: " + str(category_iou))
+    logging.info("[EVAL] Category Acc: " + str(category_acc))
+    logging.info("[EVAL] Kappa: {:.4f}".format(conf_mat.kappa()))
+
+
+def main(args):
+    env_info = get_environ_info()
+    places = fluid.CUDAPlace(ParallelEnv().dev_id) \
+        if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \
+        else fluid.CPUPlace()
+
+    if args.dataset.lower() == 'opticdiscseg':
+        dataset = OpticDiscSeg
+    elif args.dataset.lower() == 'cityscapes':
+        dataset = Cityscapes
+    else:
+        raise Exception(
+            "--dataset is invalid. It should be one of "
+            "('OpticDiscSeg', 'Cityscapes').")
+
+    with fluid.dygraph.guard(places):
+        eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
+        eval_dataset = dataset(transforms=eval_transforms, mode='eval')
+
+        if args.model_name == 'UNet':
+            model = models.UNet(num_classes=eval_dataset.num_classes)
+        else:
+            raise Exception("--model_name is invalid. It should be 'UNet'.")
+
+        evaluate(
+            model,
+            eval_dataset,
+            places=places,
+            model_dir=args.model_dir,
+            num_classes=eval_dataset.num_classes,
+            batch_size=args.batch_size)
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    main(args)
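For reference, an evaluation run with the flags this script actually defines looks like the following; `output/epoch_1` is an assumed checkpoint directory produced by `train.py`:

```
python3 val.py --model_name UNet \
--dataset OpticDiscSeg \
--input_size 192 192 \
--batch_size 2 \
--model_dir output/epoch_1
```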