提交 1a3f44a5 编写于 作者: C chenguowei01

Merge commit 'refs/pull/362/head' of https://github.com/PaddlePaddle/PaddleSeg into dygraph

......@@ -12,4 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import dygraph.models
\ No newline at end of file
from . import models
from . import datasets
from . import transforms
batch_size: 2
iters: 40000
train_dataset:
type: Cityscapes
dataset_root: datasets/cityscapes
transforms:
- type: RandomHorizontalFlip
- type: ResizeStepScaling
min_scale_factor: 0.5
max_scale_factor: 2.0
scale_step_size: 0.25
- type: RandomPaddingCrop
crop_size: [1024, 512]
- type: Normalize
mode: train
val_dataset:
type: Cityscapes
dataset_root: datasets/cityscapes
transforms:
- type: Normalize
mode: val
model:
type: ocrnet
backbone:
type: HRNet_W18
pretrained: dygraph/pretrained_model/hrnet_w18_ssld/model
num_classes: 19
in_channels: 270
optimizer:
type: sgd
learning_rate:
value: 0.01
decay:
type: poly
power: 0.9
loss:
type: CrossEntropy
......@@ -49,14 +49,15 @@ class ComponentManager:
return len(self._components_dict)
def __repr__(self):
return "{}:{}".format(self.__class__.__name__, list(self._components_dict.keys()))
return "{}:{}".format(self.__class__.__name__,
list(self._components_dict.keys()))
def __getitem__(self, item):
if item not in self._components_dict.keys():
raise KeyError("{} does not exist in the current {}".format(item, self))
raise KeyError("{} does not exist in the current {}".format(
item, self))
return self._components_dict[item]
@property
def components_dict(self):
return self._components_dict
......@@ -74,7 +75,9 @@ class ComponentManager:
# Currently only support class or function type
if not (inspect.isclass(component) or inspect.isfunction(component)):
raise TypeError("Expect class/function type, but received {}".format(type(component)))
raise TypeError(
"Expect class/function type, but received {}".format(
type(component)))
# Obtain the internal name of the component
component_name = component.__name__
......@@ -107,5 +110,8 @@ class ComponentManager:
return components
MODELS = ComponentManager()
BACKBONES = ComponentManager()
DATASETS = ComponentManager()
TRANSFORMS = ComponentManager()
......@@ -19,11 +19,14 @@ from PIL import Image
from .dataset import Dataset
from dygraph.utils.download import download_file_and_uncompress
from dygraph.cvlibs import manager
from dygraph.transforms import Compose
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = "http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip"
@manager.DATASETS.add_component
class ADE20K(Dataset):
"""ADE20K dataset `http://sceneparsing.csail.mit.edu/`.
Args:
......@@ -39,7 +42,7 @@ class ADE20K(Dataset):
transforms=None,
download=True):
self.dataset_root = dataset_root
self.transforms = transforms
self.transforms = Compose(transforms)
self.mode = mode
self.file_list = list()
self.num_classes = 150
......
......@@ -16,8 +16,11 @@ import os
import glob
from .dataset import Dataset
from dygraph.cvlibs import manager
from dygraph.transforms import Compose
@manager.DATASETS.add_component
class Cityscapes(Dataset):
"""Cityscapes dataset `https://www.cityscapes-dataset.com/`.
The folder structure is as follow:
......@@ -42,7 +45,7 @@ class Cityscapes(Dataset):
def __init__(self, dataset_root, transforms=None, mode='train'):
self.dataset_root = dataset_root
self.transforms = transforms
self.transforms = Compose(transforms)
self.file_list = list()
self.mode = mode
self.num_classes = 19
......
......@@ -17,8 +17,12 @@ import os
import paddle.fluid as fluid
import numpy as np
from PIL import Image
from dygraph.cvlibs import manager
from dygraph.transforms import Compose
@manager.DATASETS.add_component
class Dataset(fluid.io.Dataset):
"""Pass in a custom dataset that conforms to the format.
......@@ -52,7 +56,7 @@ class Dataset(fluid.io.Dataset):
separator=' ',
transforms=None):
self.dataset_root = dataset_root
self.transforms = transforms
self.transforms = Compose(transforms)
self.file_list = list()
self.mode = mode
self.num_classes = num_classes
......
......@@ -16,11 +16,14 @@ import os
from .dataset import Dataset
from dygraph.utils.download import download_file_and_uncompress
from dygraph.cvlibs import manager
from dygraph.transforms import Compose
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = "https://paddleseg.bj.bcebos.com/dataset/optic_disc_seg.zip"
@manager.DATASETS.add_component
class OpticDiscSeg(Dataset):
def __init__(self,
dataset_root=None,
......@@ -28,7 +31,7 @@ class OpticDiscSeg(Dataset):
mode='train',
download=True):
self.dataset_root = dataset_root
self.transforms = transforms
self.transforms = Compose(transforms)
self.file_list = list()
self.mode = mode
self.num_classes = 2
......
......@@ -13,13 +13,17 @@
# limitations under the License.
import os
from .dataset import Dataset
from dygraph.utils.download import download_file_and_uncompress
from dygraph.cvlibs import manager
from dygraph.transforms import Compose
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar"
@manager.DATASETS.add_component
class PascalVOC(Dataset):
"""Pascal VOC dataset `http://host.robots.ox.ac.uk/pascal/VOC/`. If you want to augment the dataset,
please run the voc_augment.py in tools.
......@@ -36,7 +40,7 @@ class PascalVOC(Dataset):
transforms=None,
download=True):
self.dataset_root = dataset_root
self.transforms = transforms
self.transforms = Compose(transforms)
self.mode = mode
self.file_list = list()
self.num_classes = 21
......
......@@ -17,3 +17,4 @@ from .unet import UNet
from .deeplab import *
from .fcn import *
from .pspnet import *
from .ocrnet import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
from paddle.fluid.dygraph import Sequential, Conv2D
from dygraph.cvlibs import manager
from dygraph.models.architectures.layer_utils import ConvBnRelu
class SpatialGatherBlock(fluid.dygraph.Layer):
def forward(self, pixels, regions):
n, c, h, w = pixels.shape
_, k, _, _ = regions.shape
# pixels: from (n, c, h, w) to (n, h*w, c)
pixels = fluid.layers.reshape(pixels, (n, c, h * w))
pixels = fluid.layers.transpose(pixels, (0, 2, 1))
# regions: from (n, k, h, w) to (n, k, h*w)
regions = fluid.layers.reshape(regions, (n, k, h * w))
regions = fluid.layers.softmax(regions, axis=2)
# feats: from (n, k, c) to (n, c, k, 1)
feats = fluid.layers.matmul(regions, pixels)
feats = fluid.layers.transpose(feats, (0, 2, 1))
feats = fluid.layers.unsqueeze(feats, axes=[-1])
return feats
class SpatialOCRModule(fluid.dygraph.Layer):
def __init__(self,
in_channels,
key_channels,
out_channels,
dropout_rate=0.1):
super(SpatialOCRModule, self).__init__()
self.attention_block = ObjectAttentionBlock(in_channels, key_channels)
self.dropout_rate = dropout_rate
self.conv1x1 = Conv2D(2 * in_channels, out_channels, 1)
def forward(self, pixels, regions):
context = self.attention_block(pixels, regions)
feats = fluid.layers.concat([context, pixels], axis=1)
feats = self.conv1x1(feats)
feats = fluid.layers.dropout(feats, self.dropout_rate)
return feats
class ObjectAttentionBlock(fluid.dygraph.Layer):
def __init__(self, in_channels, key_channels):
super(ObjectAttentionBlock, self).__init__()
self.in_channels = in_channels
self.key_channels = key_channels
self.f_pixel = Sequential(
ConvBnRelu(in_channels, key_channels, 1),
ConvBnRelu(key_channels, key_channels, 1))
self.f_object = Sequential(
ConvBnRelu(in_channels, key_channels, 1),
ConvBnRelu(key_channels, key_channels, 1))
self.f_down = ConvBnRelu(in_channels, key_channels, 1)
self.f_up = ConvBnRelu(key_channels, in_channels, 1)
def forward(self, x, proxy):
n, _, h, w = x.shape
# query : from (n, c1, h1, w1) to (n, h1*w1, key_channels)
query = self.f_pixel(x)
query = fluid.layers.reshape(query, (n, self.key_channels, -1))
query = fluid.layers.transpose(query, (0, 2, 1))
# key : from (n, c2, h2, w2) to (n, key_channels, h2*w2)
key = self.f_object(proxy)
key = fluid.layers.reshape(key, (n, self.key_channels, -1))
# value : from (n, c2, h2, w2) to (n, h2*w2, key_channels)
value = self.f_down(proxy)
value = fluid.layers.reshape(value, (n, self.key_channels, -1))
value = fluid.layers.transpose(value, (0, 2, 1))
# sim_map (n, h1*w1, h2*w2)
sim_map = fluid.layers.matmul(query, key)
sim_map = (self.key_channels**-.5) * sim_map
sim_map = fluid.layers.softmax(sim_map, axis=-1)
# context from (n, h1*w1, key_channels) to (n , out_channels, h1, w1)
context = fluid.layers.matmul(sim_map, value)
context = fluid.layers.transpose(context, (0, 2, 1))
context = fluid.layers.reshape(context, (n, self.key_channels, h, w))
context = self.f_up(context)
return context
@manager.MODELS.add_component
class OCRNet(fluid.dygraph.Layer):
def __init__(self,
num_classes,
in_channels,
backbone,
ocr_mid_channels=512,
ocr_key_channels=256,
ignore_index=255):
super(OCRNet, self).__init__()
self.ignore_index = ignore_index
self.num_classes = num_classes
self.EPS = 1e-5
self.backbone = backbone
self.spatial_gather = SpatialGatherBlock()
self.spatial_ocr = SpatialOCRModule(ocr_mid_channels, ocr_key_channels,
ocr_mid_channels)
self.conv3x3_ocr = ConvBnRelu(
in_channels, ocr_mid_channels, 3, padding=1)
self.cls_head = Conv2D(ocr_mid_channels, self.num_classes, 1)
self.aux_head = Sequential(
ConvBnRelu(in_channels, in_channels, 3, padding=1),
Conv2D(in_channels, self.num_classes, 1))
def forward(self, x, label=None):
feats = self.backbone(x)
soft_regions = self.aux_head(feats)
pixels = self.conv3x3_ocr(feats)
object_regions = self.spatial_gather(pixels, soft_regions)
ocr = self.spatial_ocr(pixels, object_regions)
logit = self.cls_head(ocr)
logit = fluid.layers.resize_bilinear(logit, x.shape[2:])
if self.training:
soft_regions = fluid.layers.resize_bilinear(soft_regions,
x.shape[2:])
cls_loss = self._get_loss(logit, label)
aux_loss = self._get_loss(soft_regions, label)
return cls_loss + 0.4 * aux_loss
score_map = fluid.layers.softmax(logit, axis=1)
score_map = fluid.layers.transpose(score_map, [0, 2, 3, 1])
pred = fluid.layers.argmax(score_map, axis=3)
pred = fluid.layers.unsqueeze(pred, axes=[3])
return pred, score_map
def _get_loss(self, logit, label):
"""
compute forward loss of the model
Args:
logit (tensor): the logit of model output
label (tensor): ground truth
Returns:
avg_loss (tensor): forward loss
"""
logit = fluid.layers.transpose(logit, [0, 2, 3, 1])
label = fluid.layers.transpose(label, [0, 2, 3, 1])
mask = label != self.ignore_index
mask = fluid.layers.cast(mask, 'float32')
loss, probs = fluid.layers.softmax_with_cross_entropy(
logit,
label,
ignore_index=self.ignore_index,
return_softmax=True,
axis=-1)
loss = loss * mask
avg_loss = fluid.layers.mean(loss) / (
fluid.layers.mean(mask) + self.EPS)
label.stop_gradient = True
mask.stop_gradient = True
return avg_loss
......@@ -17,78 +17,36 @@ import argparse
import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv
from dygraph.datasets import DATASETS
import dygraph.transforms as T
import dygraph
from dygraph.cvlibs import manager
from dygraph.utils import get_environ_info
from dygraph.utils import logger
from dygraph.utils import Config
from dygraph.core import train
def parse_args():
parser = argparse.ArgumentParser(description='Model training')
# params of model
parser.add_argument(
'--model_name',
dest='model_name',
help='Model type for training, which is one of {}'.format(
str(list(manager.MODELS.components_dict.keys()))),
type=str,
default='UNet')
# params of dataset
parser.add_argument(
'--dataset',
dest='dataset',
help="The dataset you want to train, which is one of {}".format(
str(list(DATASETS.keys()))),
type=str,
default='OpticDiscSeg')
parser.add_argument(
'--dataset_root',
dest='dataset_root',
help="dataset root directory",
type=str,
default=None)
# params of training
parser.add_argument(
"--input_size",
dest="input_size",
help="The image size for net inputs.",
nargs=2,
default=[512, 512],
type=int)
"--config", dest="cfg", help="The config file.", default=None, type=str)
parser.add_argument(
'--iters',
dest='iters',
help='iters for training',
type=int,
default=10000)
default=None)
parser.add_argument(
'--batch_size',
dest='batch_size',
help='Mini batch size of one gpu or cpu',
type=int,
default=2)
default=None)
parser.add_argument(
'--learning_rate',
dest='learning_rate',
help='Learning rate',
type=float,
default=0.01)
parser.add_argument(
'--pretrained_model',
dest='pretrained_model',
help='The path of pretrained model',
type=str,
default=None)
parser.add_argument(
'--resume_model',
dest='resume_model',
help='The path of resume model',
type=str,
default=None)
parser.add_argument(
'--save_interval_iters',
......@@ -139,59 +97,28 @@ def main(args):
if env_info['Paddle compiled with cuda'] and env_info['GPUs used'] \
else fluid.CPUPlace()
if args.dataset not in DATASETS:
raise Exception('`--dataset` is invalid. it should be one of {}'.format(
str(list(DATASETS.keys()))))
dataset = DATASETS[args.dataset]
with fluid.dygraph.guard(places):
# Creat dataset reader
train_transforms = T.Compose([
T.Resize(args.input_size),
T.RandomHorizontalFlip(),
T.Normalize()
])
train_dataset = dataset(
dataset_root=args.dataset_root,
transforms=train_transforms,
mode='train')
if not args.cfg:
raise RuntimeError('No configuration file specified.')
eval_dataset = None
if args.do_eval:
eval_transforms = T.Compose(
[T.Resize(args.input_size),
T.Normalize()])
eval_dataset = dataset(
dataset_root=args.dataset_root,
transforms=eval_transforms,
mode='val')
model = manager.MODELS[args.model_name](
num_classes=train_dataset.num_classes,
pretrained_model=args.pretrained_model)
cfg = Config(args.cfg)
train_dataset = cfg.train_dataset
if not train_dataset:
raise RuntimeError(
'The training dataset is not specified in the configuration file.'
)
# Creat optimizer
# todo, may less one than len(loader)
num_iters_each_epoch = len(train_dataset) // (
args.batch_size * ParallelEnv().nranks)
lr_decay = fluid.layers.polynomial_decay(
args.learning_rate, args.iters, end_learning_rate=0, power=0.9)
optimizer = fluid.optimizer.Momentum(
lr_decay,
momentum=0.9,
parameter_list=model.parameters(),
regularization=fluid.regularizer.L2Decay(regularization_coeff=4e-5))
val_dataset = cfg.val_dataset if args.do_eval else None
train(
model,
cfg.model,
train_dataset,
places=places,
eval_dataset=eval_dataset,
optimizer=optimizer,
eval_dataset=val_dataset,
optimizer=cfg.optimizer,
save_dir=args.save_dir,
iters=args.iters,
batch_size=args.batch_size,
resume_model=args.resume_model,
iters=cfg.iters,
batch_size=cfg.batch_size,
save_interval_iters=args.save_interval_iters,
log_iters=args.log_iters,
num_classes=train_dataset.num_classes,
......
......@@ -21,8 +21,10 @@ from PIL import Image
import cv2
from .functional import *
from dygraph.cvlibs import manager
@manager.TRANSFORMS.add_component
class Compose:
def __init__(self, transforms, to_rgb=True):
if not isinstance(transforms, list):
......@@ -58,6 +60,7 @@ class Compose:
return (im, im_info, label)
@manager.TRANSFORMS.add_component
class RandomHorizontalFlip:
def __init__(self, prob=0.5):
self.prob = prob
......@@ -73,6 +76,7 @@ class RandomHorizontalFlip:
return (im, im_info, label)
@manager.TRANSFORMS.add_component
class RandomVerticalFlip:
def __init__(self, prob=0.1):
self.prob = prob
......@@ -88,6 +92,7 @@ class RandomVerticalFlip:
return (im, im_info, label)
@manager.TRANSFORMS.add_component
class Resize:
# The interpolation mode
interp_dict = {
......@@ -137,6 +142,7 @@ class Resize:
return (im, im_info, label)
@manager.TRANSFORMS.add_component
class ResizeByLong:
def __init__(self, long_size):
self.long_size = long_size
......@@ -156,6 +162,7 @@ class ResizeByLong:
return (im, im_info, label)
@manager.TRANSFORMS.add_component
class ResizeRangeScaling:
def __init__(self, min_value=400, max_value=600):
if min_value > max_value:
......@@ -181,6 +188,7 @@ class ResizeRangeScaling:
return (im, im_info, label)
@manager.TRANSFORMS.add_component
class ResizeStepScaling:
def __init__(self,
min_scale_factor=0.75,
......@@ -224,6 +232,7 @@ class ResizeStepScaling:
return (im, im_info, label)
@manager.TRANSFORMS.add_component
class Normalize:
def __init__(self, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
self.mean = mean
......@@ -245,6 +254,7 @@ class Normalize:
return (im, im_info, label)
@manager.TRANSFORMS.add_component
class Padding:
def __init__(self,
target_size,
......@@ -305,6 +315,7 @@ class Padding:
return (im, im_info, label)
@manager.TRANSFORMS.add_component
class RandomPaddingCrop:
def __init__(self,
crop_size=512,
......@@ -378,6 +389,7 @@ class RandomPaddingCrop:
return (im, im_info, label)
@manager.TRANSFORMS.add_component
class RandomBlur:
def __init__(self, prob=0.1):
self.prob = prob
......@@ -404,6 +416,7 @@ class RandomBlur:
return (im, im_info, label)
@manager.TRANSFORMS.add_component
class RandomRotation:
def __init__(self,
max_rotation=15,
......@@ -451,6 +464,7 @@ class RandomRotation:
return (im, im_info, label)
@manager.TRANSFORMS.add_component
class RandomScaleAspect:
def __init__(self, min_scale=0.5, aspect_ratio=0.33):
self.min_scale = min_scale
......@@ -492,6 +506,7 @@ class RandomScaleAspect:
return (im, im_info, label)
@manager.TRANSFORMS.add_component
class RandomDistort:
def __init__(self,
brightness_range=0.5,
......
......@@ -18,3 +18,4 @@ from .metrics import ConfusionMatrix
from .utils import *
from .timer import Timer, calculate_eta
from .get_environ_info import get_environ_info
from .config import Config
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import codecs
import os
from typing import Any, Callable
import yaml
import paddle.fluid as fluid
import dygraph.cvlibs.manager as manager
class Config(object):
'''
Training config.
Args:
path(str) : the path of config file, supports yaml format only
'''
def __init__(self, path: str):
if not os.path.exists(path):
raise FileNotFoundError('File {} does not exist'.format(path))
if path.endswith('yml') or path.endswith('yaml'):
self._parse_from_yaml(path)
else:
raise RuntimeError('Config file should in yaml format!')
def _parse_from_yaml(self, path: str):
'''Parse a yaml file and build config'''
with codecs.open(path, 'r', 'utf-8') as file:
dic = yaml.load(file, Loader=yaml.FullLoader)
self._build(dic)
def _build(self, dic: dict):
'''Build config from dictionary'''
dic = dic.copy()
self._batch_size = dic.get('batch_size', 1)
self._iters = dic.get('iters')
if 'model' not in dic:
raise RuntimeError()
self._model_cfg = dic['model']
self._model = None
self._train_dataset = dic.get('train_dataset')
self._val_dataset = dic.get('val_dataset')
self._learning_rate_cfg = dic.get('learning_rate', {})
self._learning_rate = self._learning_rate_cfg.get('value')
self._decay = self._learning_rate_cfg.get('decay', {
'type': 'poly',
'power': 0.9
})
self._loss_cfg = dic.get('loss', {})
self._optimizer_cfg = dic.get('optimizer', {})
def update(self,
learning_rate: float = None,
batch_size: int = None,
iters: int = None):
'''Update config'''
if learning_rate:
self._learning_rate = learning_rate
if batch_size:
self._batch_size = batch_size
if iters:
self._iters = iters
@property
def batch_size(self) -> int:
return self._batch_size
@property
def iters(self) -> int:
if not self._iters:
raise RuntimeError('No iters specified in the configuration file.')
return self._iters
@property
def learning_rate(self) -> float:
if not self._learning_rate:
raise RuntimeError(
'No learning rate specified in the configuration file.')
if self.decay_type == 'poly':
lr = self._learning_rate
args = self.decay_args
args.setdefault('decay_steps', self.iters)
return fluid.layers.polynomial_decay(lr, **args)
else:
raise RuntimeError('Only poly decay support.')
@property
def optimizer(self) -> fluid.optimizer.Optimizer:
if self.optimizer_type == 'sgd':
lr = self.learning_rate
args = self.optimizer_args
args.setdefault('momentum', 0.9)
return fluid.optimizer.Momentum(
lr, parameter_list=self.model.parameters(), **args)
else:
raise RuntimeError('Only sgd optimizer support.')
@property
def optimizer_type(self) -> str:
otype = self._optimizer_cfg.get('type')
if not otype:
raise RuntimeError(
'No optimizer type specified in the configuration file.')
return otype
@property
def optimizer_args(self) -> dict:
args = self._optimizer_cfg.copy()
args.pop('type')
return args
@property
def decay_type(self) -> str:
return self._decay['type']
@property
def decay_args(self) -> dict:
args = self._decay.copy()
args.pop('type')
return args
@property
def loss_type(self) -> str:
...
@property
def loss_args(self) -> dict:
args = self._loss_cfg.copy()
args.pop('type')
return args
@property
def model(self) -> Callable:
if not self._model:
self._model = self._load_object(self._model_cfg)
return self._model
@property
def train_dataset(self) -> Any:
if not self._train_dataset:
return None
return self._load_object(self._train_dataset)
@property
def val_dataset(self) -> Any:
if not self._val_dataset:
return None
return self._load_object(self._val_dataset)
def _load_component(self, com_name: str) -> Any:
com_list = [
manager.MODELS, manager.BACKBONES, manager.DATASETS,
manager.TRANSFORMS
]
for com in com_list:
if com_name in com.components_dict:
return com[com_name]
else:
raise RuntimeError(
'The specified component was not found {}.'.format(com_name))
def _load_object(self, cfg: dict) -> Any:
cfg = cfg.copy()
if 'type' not in cfg:
raise RuntimeError('No object information in {}.'.format(cfg))
component = self._load_component(cfg.pop('type'))
params = {}
for key, val in cfg.items():
if self._is_meta_type(val):
params[key] = self._load_object(val)
elif isinstance(val, list):
params[key] = [
self._load_object(item)
if self._is_meta_type(item) else item for item in val
]
else:
params[key] = val
return component(**params)
def _is_meta_type(self, item: Any) -> bool:
return isinstance(item, dict) and 'type' in item
......@@ -17,48 +17,19 @@ import argparse
import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv
from dygraph.datasets import DATASETS
import dygraph.transforms as T
import dygraph
from dygraph.cvlibs import manager
from dygraph.utils import get_environ_info
from dygraph.utils import Config
from dygraph.core import evaluate
def parse_args():
parser = argparse.ArgumentParser(description='Model evaluation')
# params of model
parser.add_argument(
'--model_name',
dest='model_name',
help='Model type for evaluation, which is one of {}'.format(
str(list(manager.MODELS.components_dict.keys()))),
type=str,
default='UNet')
# params of dataset
parser.add_argument(
'--dataset',
dest='dataset',
help="The dataset you want to evaluation, which is one of {}".format(
str(list(DATASETS.keys()))),
type=str,
default='OpticDiscSeg')
parser.add_argument(
'--dataset_root',
dest='dataset_root',
help="dataset root directory",
type=str,
default=None)
# params of evaluate
parser.add_argument(
"--input_size",
dest="input_size",
help="The image size for net inputs.",
nargs=2,
default=[512, 512],
type=int)
"--config", dest="cfg", help="The config file.", default=None, type=str)
parser.add_argument(
'--model_dir',
dest='model_dir',
......@@ -75,26 +46,21 @@ def main(args):
if env_info['Paddle compiled with cuda'] and env_info['GPUs used'] \
else fluid.CPUPlace()
if args.dataset not in DATASETS:
raise Exception('`--dataset` is invalid. it should be one of {}'.format(
str(list(DATASETS.keys()))))
dataset = DATASETS[args.dataset]
with fluid.dygraph.guard(places):
eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
eval_dataset = dataset(
dataset_root=args.dataset_root,
transforms=eval_transforms,
mode='val')
model = manager.MODELS[args.model_name](
num_classes=eval_dataset.num_classes)
if not args.cfg:
raise RuntimeError('No configuration file specified.')
cfg = Config(args.cfg)
val_dataset = cfg.val_dataset
if not val_dataset:
raise RuntimeError(
'The verification dataset is not specified in the configuration file.'
)
evaluate(
model,
eval_dataset,
cfg.model,
val_dataset,
model_dir=args.model_dir,
num_classes=eval_dataset.num_classes)
num_classes=val_dataset.num_classes)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册