# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import os
from typing import Any, Callable

import yaml

import paddle
import paddle.nn.functional as F

import paddleseg.cvlibs.manager as manager
from paddleseg.utils import logger


class Config(object):
    '''
    Training config parser/builder.

    Loads a yaml configuration file (optionally inheriting from another
    file via the ``_base_`` key) and lazily builds the training components
    declared in it: model, losses, optimizer, lr decay and datasets.

    Args:
        path(str) : the path of config file, supports yaml format only

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        RuntimeError: if ``path`` is not a yaml file.
    '''

    def __init__(self, path: str):
        if not os.path.exists(path):
            raise FileNotFoundError('File {} does not exist'.format(path))

        # Lazily-built, cached components (see ``model`` / ``loss``).
        self._model = None
        self._losses = None
        if path.endswith('yml') or path.endswith('yaml'):
            self.dic = self._parse_from_yaml(path)
        else:
            raise RuntimeError('Config file should in yaml format!')

    def _update_dic(self, dic, base_dic):
        """
        Recursively merge ``dic`` on top of ``base_dic`` and return the result.

        Values in ``dic`` override those in ``base_dic``; nested dicts are
        merged key by key. ``base_dic`` is shallow-copied, not mutated.
        """
        base_dic = base_dic.copy()
        for key, val in dic.items():
            if isinstance(val, dict) and key in base_dic:
                base_dic[key] = self._update_dic(val, base_dic[key])
            else:
                base_dic[key] = val
        return base_dic

    def _parse_from_yaml(self, path: str):
        '''Parse a yaml file (following ``_base_`` inheritance) into a dict.'''
        with codecs.open(path, 'r', 'utf-8') as file:
            dic = yaml.load(file, Loader=yaml.FullLoader)

        if '_base_' in dic:
            # ``_base_`` is resolved relative to the current config file,
            # and the current file's entries win over the base's.
            cfg_dir = os.path.dirname(path)
            base_path = dic.pop('_base_')
            base_path = os.path.join(cfg_dir, base_path)
            base_dic = self._parse_from_yaml(base_path)
            dic = self._update_dic(dic, base_dic)
        return dic

    def update(self,
               learning_rate: float = None,
               batch_size: int = None,
               iters: int = None):
        '''Override config entries with command-line style arguments.

        Only truthy arguments are applied; passing ``None`` leaves the
        corresponding entry untouched.
        '''
        if learning_rate:
            self.dic['learning_rate']['value'] = learning_rate

        if batch_size:
            self.dic['batch_size'] = batch_size

        if iters:
            self.dic['iters'] = iters

    @property
    def batch_size(self) -> int:
        # Default to 1 when not specified in the config file.
        return self.dic.get('batch_size', 1)

    @property
    def iters(self) -> int:
        iters = self.dic.get('iters')
        if not iters:
            raise RuntimeError('No iters specified in the configuration file.')
        return iters

    @property
    def learning_rate(self) -> float:
        # NOTE: despite the ``float`` annotation (kept for interface
        # compatibility), this builds and returns a paddle lr scheduler
        # from the base value plus the decay settings.
        _learning_rate = self.dic.get('learning_rate', {}).get('value')
        if not _learning_rate:
            raise RuntimeError(
                'No learning rate specified in the configuration file.')

        args = self.decay_args
        decay_type = args.pop('type')

        if decay_type == 'poly':
            lr = _learning_rate
            return paddle.optimizer.PolynomialLR(lr, **args)
        else:
            raise RuntimeError('Only poly decay support.')

    @property
    def optimizer(self) -> 'paddle.optimizer.Optimizer':
        '''Build the optimizer declared in the config (sgd only).'''
        args = self.optimizer_args
        optimizer_type = args.pop('type')

        if optimizer_type == 'sgd':
            lr = self.learning_rate
            return paddle.optimizer.Momentum(
                lr, parameters=self.model.parameters(), **args)
        else:
            raise RuntimeError('Only sgd optimizer support.')

    @property
    def optimizer_args(self) -> dict:
        '''Optimizer settings as a dict, with sgd defaults filled in.'''
        args = self.dic.get('optimizer', {}).copy()
        # ``.get`` avoids a KeyError when the optimizer section omits
        # ``type``; ``optimizer`` will still reject unsupported types.
        if args.get('type') == 'sgd':
            args.setdefault('momentum', 0.9)
        return args

    @property
    def decay_args(self) -> dict:
        '''Learning-rate decay settings, defaulting to poly decay.'''
        args = self.dic.get('learning_rate', {}).get('decay', {
            'type': 'poly',
            'power': 0.9
        }).copy()

        if args['type'] == 'poly':
            # Poly decay runs over the whole training schedule by default.
            args.setdefault('decay_steps', self.iters)
            args.setdefault('end_lr', 0)
        return args

    @property
    def loss(self) -> list:
        '''Build (once) the loss config.

        Returns a dict like ``{'types': [loss_layer, ...], 'coef': [...]}``
        (the ``list`` annotation is kept for interface compatibility).
        '''
        # ``is None`` (not truthiness) so an empty-but-built dict is cached.
        if self._losses is None:
            args = self.dic.get('loss', {}).copy()
            self._losses = dict()
            for key, val in args.items():
                if key == 'types':
                    self._losses['types'] = [
                        self._load_object(item) for item in val
                    ]
                else:
                    self._losses[key] = val
            if len(self._losses['coef']) != len(self._losses['types']):
                raise RuntimeError(
                    'The length of coef should equal to types in loss config: {} != {}.'
                    .format(
                        len(self._losses['coef']), len(self._losses['types'])))
        return self._losses

    @property
    def model(self) -> Callable:
        '''Build (once) and return the model declared in the config.'''
        # Check before touching the value: a missing key must raise the
        # explicit RuntimeError, not AttributeError on ``None.copy()``.
        model_cfg = self.dic.get('model')
        if not model_cfg:
            raise RuntimeError('No model specified in the configuration file.')
        if not self._model:
            # ``_load_object`` copies its argument, so no extra copy needed.
            self._model = self._load_object(model_cfg)
        return self._model

    @property
    def train_dataset(self) -> Any:
        '''Build the training dataset, or return None if not configured.'''
        _train_dataset = self.dic.get('train_dataset')
        if not _train_dataset:
            return None
        return self._load_object(_train_dataset)

    @property
    def val_dataset(self) -> Any:
        '''Build the validation dataset, or return None if not configured.'''
        _val_dataset = self.dic.get('val_dataset')
        if not _val_dataset:
            return None
        return self._load_object(_val_dataset)

    def _load_component(self, com_name: str) -> Any:
        '''Look up a registered component class by name across all managers.'''
        com_list = [
            manager.MODELS, manager.BACKBONES, manager.DATASETS,
            manager.TRANSFORMS, manager.LOSSES
        ]

        for com in com_list:
            if com_name in com.components_dict:
                return com[com_name]
        raise RuntimeError(
            'The specified component was not found {}.'.format(com_name))

    def _load_object(self, cfg: dict) -> Any:
        '''Recursively instantiate an object from its config dict.

        ``cfg`` must contain a ``type`` key naming a registered component;
        the remaining keys become constructor kwargs, with nested
        ``{'type': ...}`` dicts (also inside lists) built recursively.
        '''
        cfg = cfg.copy()
        if 'type' not in cfg:
            raise RuntimeError('No object information in {}.'.format(cfg))

        component = self._load_component(cfg.pop('type'))

        params = {}
        for key, val in cfg.items():
            if self._is_meta_type(val):
                params[key] = self._load_object(val)
            elif isinstance(val, list):
                params[key] = [
                    self._load_object(item)
                    if self._is_meta_type(item) else item for item in val
                ]
            else:
                params[key] = val

        return component(**params)

    def _is_meta_type(self, item: Any) -> bool:
        # A "meta" dict describes an object to instantiate via ``type``.
        return isinstance(item, dict) and 'type' in item

    def __str__(self) -> str:
        return yaml.dump(self.dic)