提交 1a3f44a5 编写于 作者: C chenguowei01

Merge commit 'refs/pull/362/head' of https://github.com/PaddlePaddle/PaddleSeg into dygraph

...@@ -12,4 +12,6 @@ ...@@ -12,4 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import dygraph.models from . import models
\ No newline at end of file from . import datasets
from . import transforms
# Training configuration: OCRNet with an HRNet-W18 backbone on Cityscapes.
batch_size: 2   # samples per GPU per step
iters: 40000    # total training iterations

train_dataset:
  type: Cityscapes
  dataset_root: datasets/cityscapes
  transforms:
    - type: RandomHorizontalFlip
    - type: ResizeStepScaling        # random scale in [0.5, 2.0], step 0.25
      min_scale_factor: 0.5
      max_scale_factor: 2.0
      scale_step_size: 0.25
    - type: RandomPaddingCrop        # crop to fixed size (width, height)
      crop_size: [1024, 512]
    - type: Normalize
  mode: train

val_dataset:
  type: Cityscapes
  dataset_root: datasets/cityscapes
  transforms:
    - type: Normalize
  mode: val

model:
  type: ocrnet
  backbone:
    type: HRNet_W18
    pretrained: dygraph/pretrained_model/hrnet_w18_ssld/model
  num_classes: 19      # Cityscapes evaluation classes
  in_channels: 270     # concatenated HRNet-W18 branch channels

optimizer:
  type: sgd

learning_rate:
  value: 0.01
  decay:
    type: poly         # polynomial decay over `iters`
    power: 0.9

loss:
  type: CrossEntropy
...@@ -44,19 +44,20 @@ class ComponentManager: ...@@ -44,19 +44,20 @@ class ComponentManager:
def __init__(self): def __init__(self):
self._components_dict = dict() self._components_dict = dict()
def __len__(self): def __len__(self):
return len(self._components_dict) return len(self._components_dict)
def __repr__(self): def __repr__(self):
return "{}:{}".format(self.__class__.__name__, list(self._components_dict.keys())) return "{}:{}".format(self.__class__.__name__,
list(self._components_dict.keys()))
def __getitem__(self, item): def __getitem__(self, item):
if item not in self._components_dict.keys(): if item not in self._components_dict.keys():
raise KeyError("{} does not exist in the current {}".format(item, self)) raise KeyError("{} does not exist in the current {}".format(
item, self))
return self._components_dict[item] return self._components_dict[item]
@property @property
def components_dict(self): def components_dict(self):
return self._components_dict return self._components_dict
...@@ -74,7 +75,9 @@ class ComponentManager: ...@@ -74,7 +75,9 @@ class ComponentManager:
# Currently only support class or function type # Currently only support class or function type
if not (inspect.isclass(component) or inspect.isfunction(component)): if not (inspect.isclass(component) or inspect.isfunction(component)):
raise TypeError("Expect class/function type, but received {}".format(type(component))) raise TypeError(
"Expect class/function type, but received {}".format(
type(component)))
# Obtain the internal name of the component # Obtain the internal name of the component
component_name = component.__name__ component_name = component.__name__
...@@ -92,7 +95,7 @@ class ComponentManager: ...@@ -92,7 +95,7 @@ class ComponentManager:
Args: Args:
components (function | class | list | tuple): support three types of components components (function | class | list | tuple): support three types of components
Returns: Returns:
None None
""" """
...@@ -104,8 +107,11 @@ class ComponentManager: ...@@ -104,8 +107,11 @@ class ComponentManager:
else: else:
component = components component = components
self._add_single_component(component) self._add_single_component(component)
return components return components
MODELS = ComponentManager() MODELS = ComponentManager()
BACKBONES = ComponentManager() BACKBONES = ComponentManager()
\ No newline at end of file DATASETS = ComponentManager()
TRANSFORMS = ComponentManager()
...@@ -19,11 +19,14 @@ from PIL import Image ...@@ -19,11 +19,14 @@ from PIL import Image
from .dataset import Dataset from .dataset import Dataset
from dygraph.utils.download import download_file_and_uncompress from dygraph.utils.download import download_file_and_uncompress
from dygraph.cvlibs import manager
from dygraph.transforms import Compose
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset') DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = "http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip" URL = "http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip"
@manager.DATASETS.add_component
class ADE20K(Dataset): class ADE20K(Dataset):
"""ADE20K dataset `http://sceneparsing.csail.mit.edu/`. """ADE20K dataset `http://sceneparsing.csail.mit.edu/`.
Args: Args:
...@@ -39,7 +42,7 @@ class ADE20K(Dataset): ...@@ -39,7 +42,7 @@ class ADE20K(Dataset):
transforms=None, transforms=None,
download=True): download=True):
self.dataset_root = dataset_root self.dataset_root = dataset_root
self.transforms = transforms self.transforms = Compose(transforms)
self.mode = mode self.mode = mode
self.file_list = list() self.file_list = list()
self.num_classes = 150 self.num_classes = 150
......
...@@ -16,8 +16,11 @@ import os ...@@ -16,8 +16,11 @@ import os
import glob import glob
from .dataset import Dataset from .dataset import Dataset
from dygraph.cvlibs import manager
from dygraph.transforms import Compose
@manager.DATASETS.add_component
class Cityscapes(Dataset): class Cityscapes(Dataset):
"""Cityscapes dataset `https://www.cityscapes-dataset.com/`. """Cityscapes dataset `https://www.cityscapes-dataset.com/`.
The folder structure is as follow: The folder structure is as follow:
...@@ -42,7 +45,7 @@ class Cityscapes(Dataset): ...@@ -42,7 +45,7 @@ class Cityscapes(Dataset):
def __init__(self, dataset_root, transforms=None, mode='train'): def __init__(self, dataset_root, transforms=None, mode='train'):
self.dataset_root = dataset_root self.dataset_root = dataset_root
self.transforms = transforms self.transforms = Compose(transforms)
self.file_list = list() self.file_list = list()
self.mode = mode self.mode = mode
self.num_classes = 19 self.num_classes = 19
......
...@@ -17,8 +17,12 @@ import os ...@@ -17,8 +17,12 @@ import os
import paddle.fluid as fluid import paddle.fluid as fluid
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from dygraph.cvlibs import manager
from dygraph.transforms import Compose
@manager.DATASETS.add_component
class Dataset(fluid.io.Dataset): class Dataset(fluid.io.Dataset):
"""Pass in a custom dataset that conforms to the format. """Pass in a custom dataset that conforms to the format.
...@@ -52,7 +56,7 @@ class Dataset(fluid.io.Dataset): ...@@ -52,7 +56,7 @@ class Dataset(fluid.io.Dataset):
separator=' ', separator=' ',
transforms=None): transforms=None):
self.dataset_root = dataset_root self.dataset_root = dataset_root
self.transforms = transforms self.transforms = Compose(transforms)
self.file_list = list() self.file_list = list()
self.mode = mode self.mode = mode
self.num_classes = num_classes self.num_classes = num_classes
......
...@@ -16,11 +16,14 @@ import os ...@@ -16,11 +16,14 @@ import os
from .dataset import Dataset from .dataset import Dataset
from dygraph.utils.download import download_file_and_uncompress from dygraph.utils.download import download_file_and_uncompress
from dygraph.cvlibs import manager
from dygraph.transforms import Compose
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset') DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = "https://paddleseg.bj.bcebos.com/dataset/optic_disc_seg.zip" URL = "https://paddleseg.bj.bcebos.com/dataset/optic_disc_seg.zip"
@manager.DATASETS.add_component
class OpticDiscSeg(Dataset): class OpticDiscSeg(Dataset):
def __init__(self, def __init__(self,
dataset_root=None, dataset_root=None,
...@@ -28,7 +31,7 @@ class OpticDiscSeg(Dataset): ...@@ -28,7 +31,7 @@ class OpticDiscSeg(Dataset):
mode='train', mode='train',
download=True): download=True):
self.dataset_root = dataset_root self.dataset_root = dataset_root
self.transforms = transforms self.transforms = Compose(transforms)
self.file_list = list() self.file_list = list()
self.mode = mode self.mode = mode
self.num_classes = 2 self.num_classes = 2
......
...@@ -13,13 +13,17 @@ ...@@ -13,13 +13,17 @@
# limitations under the License. # limitations under the License.
import os import os
from .dataset import Dataset from .dataset import Dataset
from dygraph.utils.download import download_file_and_uncompress from dygraph.utils.download import download_file_and_uncompress
from dygraph.cvlibs import manager
from dygraph.transforms import Compose
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset') DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar" URL = "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar"
@manager.DATASETS.add_component
class PascalVOC(Dataset): class PascalVOC(Dataset):
"""Pascal VOC dataset `http://host.robots.ox.ac.uk/pascal/VOC/`. If you want to augment the dataset, """Pascal VOC dataset `http://host.robots.ox.ac.uk/pascal/VOC/`. If you want to augment the dataset,
please run the voc_augment.py in tools. please run the voc_augment.py in tools.
...@@ -36,7 +40,7 @@ class PascalVOC(Dataset): ...@@ -36,7 +40,7 @@ class PascalVOC(Dataset):
transforms=None, transforms=None,
download=True): download=True):
self.dataset_root = dataset_root self.dataset_root = dataset_root
self.transforms = transforms self.transforms = Compose(transforms)
self.mode = mode self.mode = mode
self.file_list = list() self.file_list = list()
self.num_classes = 21 self.num_classes = 21
......
...@@ -17,3 +17,4 @@ from .unet import UNet ...@@ -17,3 +17,4 @@ from .unet import UNet
from .deeplab import * from .deeplab import *
from .fcn import * from .fcn import *
from .pspnet import * from .pspnet import *
from .ocrnet import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
from paddle.fluid.dygraph import Sequential, Conv2D
from dygraph.cvlibs import manager
from dygraph.models.architectures.layer_utils import ConvBnRelu
class SpatialGatherBlock(fluid.dygraph.Layer):
    """Aggregate pixel features into soft object-region representations.

    Each of the k region maps is softmax-normalized over the spatial
    dimension and used as weights to pool the pixel features, producing
    one feature vector per region.
    """

    def forward(self, pixels, regions):
        batch, channels, height, width = pixels.shape
        _, num_regions, _, _ = regions.shape
        spatial = height * width

        # Pixels: (n, c, h, w) -> (n, h*w, c).
        flat_pixels = fluid.layers.reshape(pixels, (batch, channels, spatial))
        flat_pixels = fluid.layers.transpose(flat_pixels, (0, 2, 1))

        # Regions: (n, k, h, w) -> (n, k, h*w), normalized per region.
        flat_regions = fluid.layers.reshape(regions,
                                            (batch, num_regions, spatial))
        flat_regions = fluid.layers.softmax(flat_regions, axis=2)

        # Weighted pooling: (n, k, c) -> (n, c, k, 1).
        gathered = fluid.layers.matmul(flat_regions, flat_pixels)
        gathered = fluid.layers.transpose(gathered, (0, 2, 1))
        gathered = fluid.layers.unsqueeze(gathered, axes=[-1])

        return gathered
class SpatialOCRModule(fluid.dygraph.Layer):
    """Augment pixel features with object-contextual representations.

    Attends each pixel to the gathered region features, concatenates the
    resulting context with the original pixels, and projects back to
    ``out_channels`` with a 1x1 convolution followed by dropout.

    Args:
        in_channels: channels of the incoming pixel features.
        key_channels: channels of the attention key space.
        out_channels: channels of the fused output.
        dropout_rate: dropout probability applied after the projection.
    """

    def __init__(self,
                 in_channels,
                 key_channels,
                 out_channels,
                 dropout_rate=0.1):
        super(SpatialOCRModule, self).__init__()

        self.attention_block = ObjectAttentionBlock(in_channels, key_channels)
        self.dropout_rate = dropout_rate
        # 2 * in_channels: context and original pixels are concatenated.
        self.conv1x1 = Conv2D(2 * in_channels, out_channels, 1)

    def forward(self, pixels, regions):
        # Per-pixel object context, same spatial size as `pixels`.
        context = self.attention_block(pixels, regions)

        feats = fluid.layers.concat([context, pixels], axis=1)
        feats = self.conv1x1(feats)
        feats = fluid.layers.dropout(feats, self.dropout_rate)

        return feats
class ObjectAttentionBlock(fluid.dygraph.Layer):
    """Scaled dot-product attention from pixels to object regions.

    Pixels provide the queries, the region proxies provide keys and
    values; the attended context is projected back to ``in_channels``.
    """

    def __init__(self, in_channels, key_channels):
        super(ObjectAttentionBlock, self).__init__()

        self.in_channels = in_channels
        self.key_channels = key_channels

        # Query / key projections into the shared key space.
        self.f_pixel = Sequential(
            ConvBnRelu(in_channels, key_channels, 1),
            ConvBnRelu(key_channels, key_channels, 1))
        self.f_object = Sequential(
            ConvBnRelu(in_channels, key_channels, 1),
            ConvBnRelu(key_channels, key_channels, 1))
        self.f_down = ConvBnRelu(in_channels, key_channels, 1)
        self.f_up = ConvBnRelu(key_channels, in_channels, 1)

    def forward(self, x, proxy):
        batch, _, height, width = x.shape

        # Query: (n, c1, h1, w1) -> (n, h1*w1, key_channels).
        query = fluid.layers.reshape(
            self.f_pixel(x), (batch, self.key_channels, -1))
        query = fluid.layers.transpose(query, (0, 2, 1))

        # Key: (n, c2, h2, w2) -> (n, key_channels, h2*w2).
        key = fluid.layers.reshape(
            self.f_object(proxy), (batch, self.key_channels, -1))

        # Value: (n, c2, h2, w2) -> (n, h2*w2, key_channels).
        value = fluid.layers.reshape(
            self.f_down(proxy), (batch, self.key_channels, -1))
        value = fluid.layers.transpose(value, (0, 2, 1))

        # Similarity map (n, h1*w1, h2*w2), scaled by 1/sqrt(key_channels).
        sim_map = fluid.layers.matmul(query, key)
        sim_map = (self.key_channels**-.5) * sim_map
        sim_map = fluid.layers.softmax(sim_map, axis=-1)

        # Context: (n, h1*w1, key_channels) -> (n, in_channels, h1, w1).
        context = fluid.layers.matmul(sim_map, value)
        context = fluid.layers.transpose(context, (0, 2, 1))
        context = fluid.layers.reshape(context,
                                       (batch, self.key_channels, height, width))
        return self.f_up(context)
@manager.MODELS.add_component
class OCRNet(fluid.dygraph.Layer):
    """OCRNet segmentation head (Object-Contextual Representations).

    Backbone features are turned into soft object regions by an auxiliary
    head; pixels are then enriched with object context and classified.
    During training the auxiliary prediction is supervised as well, with
    a fixed weight of 0.4.

    Args:
        num_classes: number of segmentation classes.
        in_channels: channel count of the backbone output feature map.
        backbone: feature extraction network, called as ``backbone(x)``.
        ocr_mid_channels: channels of the OCR context branch.
        ocr_key_channels: channels of the attention key space.
        ignore_index: label value excluded from the loss.
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 backbone,
                 ocr_mid_channels=512,
                 ocr_key_channels=256,
                 ignore_index=255):
        super(OCRNet, self).__init__()

        self.ignore_index = ignore_index
        self.num_classes = num_classes
        # Guards against division by zero when (almost) all pixels are
        # ignored in the loss normalization below.
        self.EPS = 1e-5

        self.backbone = backbone
        self.spatial_gather = SpatialGatherBlock()
        self.spatial_ocr = SpatialOCRModule(ocr_mid_channels, ocr_key_channels,
                                            ocr_mid_channels)
        self.conv3x3_ocr = ConvBnRelu(
            in_channels, ocr_mid_channels, 3, padding=1)
        self.cls_head = Conv2D(ocr_mid_channels, self.num_classes, 1)
        # Auxiliary head: produces the soft region maps and is supervised
        # with the same labels during training.
        self.aux_head = Sequential(
            ConvBnRelu(in_channels, in_channels, 3, padding=1),
            Conv2D(in_channels, self.num_classes, 1))

    def forward(self, x, label=None):
        """Run the network.

        Returns the combined loss (main + 0.4 * auxiliary) in training
        mode, otherwise a ``(pred, score_map)`` tuple.
        """
        feats = self.backbone(x)

        soft_regions = self.aux_head(feats)
        pixels = self.conv3x3_ocr(feats)

        object_regions = self.spatial_gather(pixels, soft_regions)
        ocr = self.spatial_ocr(pixels, object_regions)

        logit = self.cls_head(ocr)
        # Upsample logits back to the input resolution.
        logit = fluid.layers.resize_bilinear(logit, x.shape[2:])

        if self.training:
            soft_regions = fluid.layers.resize_bilinear(soft_regions,
                                                        x.shape[2:])
            cls_loss = self._get_loss(logit, label)
            aux_loss = self._get_loss(soft_regions, label)
            return cls_loss + 0.4 * aux_loss

        score_map = fluid.layers.softmax(logit, axis=1)
        score_map = fluid.layers.transpose(score_map, [0, 2, 3, 1])
        pred = fluid.layers.argmax(score_map, axis=3)
        pred = fluid.layers.unsqueeze(pred, axes=[3])
        return pred, score_map

    def _get_loss(self, logit, label):
        """Compute the masked, normalized cross-entropy loss.

        Args:
            logit (tensor): model output in NCHW layout.
            label (tensor): ground-truth labels in NCHW layout.

        Returns:
            avg_loss (tensor): scalar loss averaged over valid pixels.
        """
        # Move class channel last for softmax_with_cross_entropy(axis=-1).
        logit = fluid.layers.transpose(logit, [0, 2, 3, 1])
        label = fluid.layers.transpose(label, [0, 2, 3, 1])
        # Mask of pixels that participate in the loss.
        mask = label != self.ignore_index
        mask = fluid.layers.cast(mask, 'float32')
        loss, probs = fluid.layers.softmax_with_cross_entropy(
            logit,
            label,
            ignore_index=self.ignore_index,
            return_softmax=True,
            axis=-1)
        loss = loss * mask
        # Renormalize by the fraction of valid pixels so the loss scale
        # does not depend on how many pixels were ignored.
        avg_loss = fluid.layers.mean(loss) / (
            fluid.layers.mean(mask) + self.EPS)

        label.stop_gradient = True
        mask.stop_gradient = True
        return avg_loss
...@@ -17,78 +17,36 @@ import argparse ...@@ -17,78 +17,36 @@ import argparse
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv from paddle.fluid.dygraph.parallel import ParallelEnv
from dygraph.datasets import DATASETS import dygraph
import dygraph.transforms as T
from dygraph.cvlibs import manager from dygraph.cvlibs import manager
from dygraph.utils import get_environ_info from dygraph.utils import get_environ_info
from dygraph.utils import logger from dygraph.utils import logger
from dygraph.utils import Config
from dygraph.core import train from dygraph.core import train
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description='Model training') parser = argparse.ArgumentParser(description='Model training')
# params of model
parser.add_argument(
'--model_name',
dest='model_name',
help='Model type for training, which is one of {}'.format(
str(list(manager.MODELS.components_dict.keys()))),
type=str,
default='UNet')
# params of dataset
parser.add_argument(
'--dataset',
dest='dataset',
help="The dataset you want to train, which is one of {}".format(
str(list(DATASETS.keys()))),
type=str,
default='OpticDiscSeg')
parser.add_argument(
'--dataset_root',
dest='dataset_root',
help="dataset root directory",
type=str,
default=None)
# params of training # params of training
parser.add_argument( parser.add_argument(
"--input_size", "--config", dest="cfg", help="The config file.", default=None, type=str)
dest="input_size",
help="The image size for net inputs.",
nargs=2,
default=[512, 512],
type=int)
parser.add_argument( parser.add_argument(
'--iters', '--iters',
dest='iters', dest='iters',
help='iters for training', help='iters for training',
type=int, type=int,
default=10000) default=None)
parser.add_argument( parser.add_argument(
'--batch_size', '--batch_size',
dest='batch_size', dest='batch_size',
help='Mini batch size of one gpu or cpu', help='Mini batch size of one gpu or cpu',
type=int, type=int,
default=2) default=None)
parser.add_argument( parser.add_argument(
'--learning_rate', '--learning_rate',
dest='learning_rate', dest='learning_rate',
help='Learning rate', help='Learning rate',
type=float, type=float,
default=0.01)
parser.add_argument(
'--pretrained_model',
dest='pretrained_model',
help='The path of pretrained model',
type=str,
default=None)
parser.add_argument(
'--resume_model',
dest='resume_model',
help='The path of resume model',
type=str,
default=None) default=None)
parser.add_argument( parser.add_argument(
'--save_interval_iters', '--save_interval_iters',
...@@ -139,59 +97,28 @@ def main(args): ...@@ -139,59 +97,28 @@ def main(args):
if env_info['Paddle compiled with cuda'] and env_info['GPUs used'] \ if env_info['Paddle compiled with cuda'] and env_info['GPUs used'] \
else fluid.CPUPlace() else fluid.CPUPlace()
if args.dataset not in DATASETS:
raise Exception('`--dataset` is invalid. it should be one of {}'.format(
str(list(DATASETS.keys()))))
dataset = DATASETS[args.dataset]
with fluid.dygraph.guard(places): with fluid.dygraph.guard(places):
# Creat dataset reader if not args.cfg:
train_transforms = T.Compose([ raise RuntimeError('No configuration file specified.')
T.Resize(args.input_size),
T.RandomHorizontalFlip(),
T.Normalize()
])
train_dataset = dataset(
dataset_root=args.dataset_root,
transforms=train_transforms,
mode='train')
eval_dataset = None cfg = Config(args.cfg)
if args.do_eval: train_dataset = cfg.train_dataset
eval_transforms = T.Compose( if not train_dataset:
[T.Resize(args.input_size), raise RuntimeError(
T.Normalize()]) 'The training dataset is not specified in the configuration file.'
eval_dataset = dataset( )
dataset_root=args.dataset_root,
transforms=eval_transforms,
mode='val')
model = manager.MODELS[args.model_name](
num_classes=train_dataset.num_classes,
pretrained_model=args.pretrained_model)
# Creat optimizer val_dataset = cfg.val_dataset if args.do_eval else None
# todo, may less one than len(loader)
num_iters_each_epoch = len(train_dataset) // (
args.batch_size * ParallelEnv().nranks)
lr_decay = fluid.layers.polynomial_decay(
args.learning_rate, args.iters, end_learning_rate=0, power=0.9)
optimizer = fluid.optimizer.Momentum(
lr_decay,
momentum=0.9,
parameter_list=model.parameters(),
regularization=fluid.regularizer.L2Decay(regularization_coeff=4e-5))
train( train(
model, cfg.model,
train_dataset, train_dataset,
places=places, places=places,
eval_dataset=eval_dataset, eval_dataset=val_dataset,
optimizer=optimizer, optimizer=cfg.optimizer,
save_dir=args.save_dir, save_dir=args.save_dir,
iters=args.iters, iters=cfg.iters,
batch_size=args.batch_size, batch_size=cfg.batch_size,
resume_model=args.resume_model,
save_interval_iters=args.save_interval_iters, save_interval_iters=args.save_interval_iters,
log_iters=args.log_iters, log_iters=args.log_iters,
num_classes=train_dataset.num_classes, num_classes=train_dataset.num_classes,
......
...@@ -21,8 +21,10 @@ from PIL import Image ...@@ -21,8 +21,10 @@ from PIL import Image
import cv2 import cv2
from .functional import * from .functional import *
from dygraph.cvlibs import manager
@manager.TRANSFORMS.add_component
class Compose: class Compose:
def __init__(self, transforms, to_rgb=True): def __init__(self, transforms, to_rgb=True):
if not isinstance(transforms, list): if not isinstance(transforms, list):
...@@ -58,6 +60,7 @@ class Compose: ...@@ -58,6 +60,7 @@ class Compose:
return (im, im_info, label) return (im, im_info, label)
@manager.TRANSFORMS.add_component
class RandomHorizontalFlip: class RandomHorizontalFlip:
def __init__(self, prob=0.5): def __init__(self, prob=0.5):
self.prob = prob self.prob = prob
...@@ -73,6 +76,7 @@ class RandomHorizontalFlip: ...@@ -73,6 +76,7 @@ class RandomHorizontalFlip:
return (im, im_info, label) return (im, im_info, label)
@manager.TRANSFORMS.add_component
class RandomVerticalFlip: class RandomVerticalFlip:
def __init__(self, prob=0.1): def __init__(self, prob=0.1):
self.prob = prob self.prob = prob
...@@ -88,6 +92,7 @@ class RandomVerticalFlip: ...@@ -88,6 +92,7 @@ class RandomVerticalFlip:
return (im, im_info, label) return (im, im_info, label)
@manager.TRANSFORMS.add_component
class Resize: class Resize:
# The interpolation mode # The interpolation mode
interp_dict = { interp_dict = {
...@@ -137,6 +142,7 @@ class Resize: ...@@ -137,6 +142,7 @@ class Resize:
return (im, im_info, label) return (im, im_info, label)
@manager.TRANSFORMS.add_component
class ResizeByLong: class ResizeByLong:
def __init__(self, long_size): def __init__(self, long_size):
self.long_size = long_size self.long_size = long_size
...@@ -156,6 +162,7 @@ class ResizeByLong: ...@@ -156,6 +162,7 @@ class ResizeByLong:
return (im, im_info, label) return (im, im_info, label)
@manager.TRANSFORMS.add_component
class ResizeRangeScaling: class ResizeRangeScaling:
def __init__(self, min_value=400, max_value=600): def __init__(self, min_value=400, max_value=600):
if min_value > max_value: if min_value > max_value:
...@@ -181,6 +188,7 @@ class ResizeRangeScaling: ...@@ -181,6 +188,7 @@ class ResizeRangeScaling:
return (im, im_info, label) return (im, im_info, label)
@manager.TRANSFORMS.add_component
class ResizeStepScaling: class ResizeStepScaling:
def __init__(self, def __init__(self,
min_scale_factor=0.75, min_scale_factor=0.75,
...@@ -224,6 +232,7 @@ class ResizeStepScaling: ...@@ -224,6 +232,7 @@ class ResizeStepScaling:
return (im, im_info, label) return (im, im_info, label)
@manager.TRANSFORMS.add_component
class Normalize: class Normalize:
def __init__(self, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]): def __init__(self, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
self.mean = mean self.mean = mean
...@@ -245,6 +254,7 @@ class Normalize: ...@@ -245,6 +254,7 @@ class Normalize:
return (im, im_info, label) return (im, im_info, label)
@manager.TRANSFORMS.add_component
class Padding: class Padding:
def __init__(self, def __init__(self,
target_size, target_size,
...@@ -305,6 +315,7 @@ class Padding: ...@@ -305,6 +315,7 @@ class Padding:
return (im, im_info, label) return (im, im_info, label)
@manager.TRANSFORMS.add_component
class RandomPaddingCrop: class RandomPaddingCrop:
def __init__(self, def __init__(self,
crop_size=512, crop_size=512,
...@@ -378,6 +389,7 @@ class RandomPaddingCrop: ...@@ -378,6 +389,7 @@ class RandomPaddingCrop:
return (im, im_info, label) return (im, im_info, label)
@manager.TRANSFORMS.add_component
class RandomBlur: class RandomBlur:
def __init__(self, prob=0.1): def __init__(self, prob=0.1):
self.prob = prob self.prob = prob
...@@ -404,6 +416,7 @@ class RandomBlur: ...@@ -404,6 +416,7 @@ class RandomBlur:
return (im, im_info, label) return (im, im_info, label)
@manager.TRANSFORMS.add_component
class RandomRotation: class RandomRotation:
def __init__(self, def __init__(self,
max_rotation=15, max_rotation=15,
...@@ -451,6 +464,7 @@ class RandomRotation: ...@@ -451,6 +464,7 @@ class RandomRotation:
return (im, im_info, label) return (im, im_info, label)
@manager.TRANSFORMS.add_component
class RandomScaleAspect: class RandomScaleAspect:
def __init__(self, min_scale=0.5, aspect_ratio=0.33): def __init__(self, min_scale=0.5, aspect_ratio=0.33):
self.min_scale = min_scale self.min_scale = min_scale
...@@ -492,6 +506,7 @@ class RandomScaleAspect: ...@@ -492,6 +506,7 @@ class RandomScaleAspect:
return (im, im_info, label) return (im, im_info, label)
@manager.TRANSFORMS.add_component
class RandomDistort: class RandomDistort:
def __init__(self, def __init__(self,
brightness_range=0.5, brightness_range=0.5,
......
...@@ -18,3 +18,4 @@ from .metrics import ConfusionMatrix ...@@ -18,3 +18,4 @@ from .metrics import ConfusionMatrix
from .utils import * from .utils import *
from .timer import Timer, calculate_eta from .timer import Timer, calculate_eta
from .get_environ_info import get_environ_info from .get_environ_info import get_environ_info
from .config import Config
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import codecs
import os
from typing import Any, Callable
import yaml
import paddle.fluid as fluid
import dygraph.cvlibs.manager as manager
class Config(object):
    '''
    Training configuration parsed from a yaml file.

    Exposes the parsed sections (model, datasets, optimizer, learning
    rate, loss) as lazily-built objects via properties.

    Args:
        path (str): the path of config file, supports yaml format only.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        RuntimeError: if ``path`` is not a yaml/yml file, or mandatory
            sections are missing.
    '''

    def __init__(self, path: str):
        if not os.path.exists(path):
            raise FileNotFoundError('File {} does not exist'.format(path))

        if path.endswith('yml') or path.endswith('yaml'):
            self._parse_from_yaml(path)
        else:
            raise RuntimeError('Config file should in yaml format!')

    def _parse_from_yaml(self, path: str):
        '''Parse a yaml file and build config'''
        with codecs.open(path, 'r', 'utf-8') as file:
            dic = yaml.load(file, Loader=yaml.FullLoader)
        self._build(dic)

    def _build(self, dic: dict):
        '''Build config from dictionary'''
        dic = dic.copy()

        self._batch_size = dic.get('batch_size', 1)
        self._iters = dic.get('iters')

        if 'model' not in dic:
            # Fix: the original raised a bare RuntimeError with no message.
            raise RuntimeError('No model specified in the configuration file.')
        self._model_cfg = dic['model']
        self._model = None  # built lazily by the `model` property

        self._train_dataset = dic.get('train_dataset')
        self._val_dataset = dic.get('val_dataset')

        self._learning_rate_cfg = dic.get('learning_rate', {})
        self._learning_rate = self._learning_rate_cfg.get('value')
        self._decay = self._learning_rate_cfg.get('decay', {
            'type': 'poly',
            'power': 0.9
        })

        self._loss_cfg = dic.get('loss', {})
        self._optimizer_cfg = dic.get('optimizer', {})

    def update(self,
               learning_rate: float = None,
               batch_size: int = None,
               iters: int = None):
        '''Override config values, typically from command-line flags.

        Uses ``is not None`` so an explicit zero is honored (plain
        truthiness would silently ignore it).
        '''
        if learning_rate is not None:
            self._learning_rate = learning_rate
        if batch_size is not None:
            self._batch_size = batch_size
        if iters is not None:
            self._iters = iters

    @property
    def batch_size(self) -> int:
        return self._batch_size

    @property
    def iters(self) -> int:
        if not self._iters:
            raise RuntimeError('No iters specified in the configuration file.')
        return self._iters

    @property
    def learning_rate(self) -> float:
        '''Return the learning rate with the configured decay applied.'''
        if not self._learning_rate:
            raise RuntimeError(
                'No learning rate specified in the configuration file.')
        if self.decay_type == 'poly':
            lr = self._learning_rate
            args = self.decay_args
            # Decay over the full training run unless configured otherwise.
            args.setdefault('decay_steps', self.iters)
            return fluid.layers.polynomial_decay(lr, **args)
        else:
            raise RuntimeError('Only poly decay support.')

    @property
    def optimizer(self) -> fluid.optimizer.Optimizer:
        '''Build the optimizer bound to the model parameters.'''
        if self.optimizer_type == 'sgd':
            lr = self.learning_rate
            args = self.optimizer_args
            args.setdefault('momentum', 0.9)
            return fluid.optimizer.Momentum(
                lr, parameter_list=self.model.parameters(), **args)
        else:
            raise RuntimeError('Only sgd optimizer support.')

    @property
    def optimizer_type(self) -> str:
        otype = self._optimizer_cfg.get('type')
        if not otype:
            raise RuntimeError(
                'No optimizer type specified in the configuration file.')
        return otype

    @property
    def optimizer_args(self) -> dict:
        args = self._optimizer_cfg.copy()
        args.pop('type', None)
        return args

    @property
    def decay_type(self) -> str:
        return self._decay['type']

    @property
    def decay_args(self) -> dict:
        args = self._decay.copy()
        args.pop('type', None)
        return args

    @property
    def loss_type(self) -> str:
        # Fix: the original property body was a bare `...`, so it always
        # returned None. Return the configured type (None when unset).
        return self._loss_cfg.get('type')

    @property
    def loss_args(self) -> dict:
        args = self._loss_cfg.copy()
        # Fix: a plain pop('type') raised KeyError when the `loss`
        # section was omitted (``_loss_cfg`` defaults to ``{}``).
        args.pop('type', None)
        return args

    @property
    def model(self) -> Callable:
        '''Build the model once and cache it.'''
        if not self._model:
            self._model = self._load_object(self._model_cfg)
        return self._model

    @property
    def train_dataset(self) -> Any:
        if not self._train_dataset:
            return None
        return self._load_object(self._train_dataset)

    @property
    def val_dataset(self) -> Any:
        if not self._val_dataset:
            return None
        return self._load_object(self._val_dataset)

    def _load_component(self, com_name: str) -> Any:
        '''Look up a registered component by name across all managers.'''
        com_list = [
            manager.MODELS, manager.BACKBONES, manager.DATASETS,
            manager.TRANSFORMS
        ]
        for com in com_list:
            if com_name in com.components_dict:
                return com[com_name]
        else:
            raise RuntimeError(
                'The specified component was not found {}.'.format(com_name))

    def _load_object(self, cfg: dict) -> Any:
        '''Recursively instantiate the object described by a `type` dict.'''
        cfg = cfg.copy()
        if 'type' not in cfg:
            raise RuntimeError('No object information in {}.'.format(cfg))
        component = self._load_component(cfg.pop('type'))

        params = {}
        for key, val in cfg.items():
            if self._is_meta_type(val):
                # Nested `type` dicts (e.g. a backbone) become objects too.
                params[key] = self._load_object(val)
            elif isinstance(val, list):
                params[key] = [
                    self._load_object(item)
                    if self._is_meta_type(item) else item for item in val
                ]
            else:
                params[key] = val

        return component(**params)

    def _is_meta_type(self, item: Any) -> bool:
        '''A dict carrying a `type` key describes a constructible object.'''
        return isinstance(item, dict) and 'type' in item
...@@ -17,48 +17,19 @@ import argparse ...@@ -17,48 +17,19 @@ import argparse
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv from paddle.fluid.dygraph.parallel import ParallelEnv
from dygraph.datasets import DATASETS import dygraph
import dygraph.transforms as T
from dygraph.cvlibs import manager from dygraph.cvlibs import manager
from dygraph.utils import get_environ_info from dygraph.utils import get_environ_info
from dygraph.utils import Config
from dygraph.core import evaluate from dygraph.core import evaluate
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description='Model evaluation') parser = argparse.ArgumentParser(description='Model evaluation')
# params of model
parser.add_argument(
'--model_name',
dest='model_name',
help='Model type for evaluation, which is one of {}'.format(
str(list(manager.MODELS.components_dict.keys()))),
type=str,
default='UNet')
# params of dataset
parser.add_argument(
'--dataset',
dest='dataset',
help="The dataset you want to evaluation, which is one of {}".format(
str(list(DATASETS.keys()))),
type=str,
default='OpticDiscSeg')
parser.add_argument(
'--dataset_root',
dest='dataset_root',
help="dataset root directory",
type=str,
default=None)
# params of evaluate # params of evaluate
parser.add_argument( parser.add_argument(
"--input_size", "--config", dest="cfg", help="The config file.", default=None, type=str)
dest="input_size",
help="The image size for net inputs.",
nargs=2,
default=[512, 512],
type=int)
parser.add_argument( parser.add_argument(
'--model_dir', '--model_dir',
dest='model_dir', dest='model_dir',
...@@ -75,26 +46,21 @@ def main(args): ...@@ -75,26 +46,21 @@ def main(args):
if env_info['Paddle compiled with cuda'] and env_info['GPUs used'] \ if env_info['Paddle compiled with cuda'] and env_info['GPUs used'] \
else fluid.CPUPlace() else fluid.CPUPlace()
if args.dataset not in DATASETS:
raise Exception('`--dataset` is invalid. it should be one of {}'.format(
str(list(DATASETS.keys()))))
dataset = DATASETS[args.dataset]
with fluid.dygraph.guard(places): with fluid.dygraph.guard(places):
eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()]) if not args.cfg:
eval_dataset = dataset( raise RuntimeError('No configuration file specified.')
dataset_root=args.dataset_root,
transforms=eval_transforms, cfg = Config(args.cfg)
mode='val') val_dataset = cfg.val_dataset
if not val_dataset:
model = manager.MODELS[args.model_name]( raise RuntimeError(
num_classes=eval_dataset.num_classes) 'The verification dataset is not specified in the configuration file.'
)
evaluate( evaluate(
model, cfg.model,
eval_dataset, val_dataset,
model_dir=args.model_dir, model_dir=args.model_dir,
num_classes=eval_dataset.num_classes) num_classes=val_dataset.num_classes)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册