# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time
import random
import datetime
import numpy as np
from PIL import Image

import paddle
from paddle.distributed import ParallelEnv
from paddle.static import InputSpec

from ppdet.core.workspace import create
from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
from ppdet.utils.visualizer import visualize_results
from ppdet.metrics import Metric, COCOMetric, VOCMetric, get_categories, get_infer_results
import ppdet.utils.stats as stats

from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer
from .export_utils import _dump_infer_config

from ppdet.utils.logger import setup_logger
# Module-level logger, created by setup_logger with this module's name.
logger = setup_logger(__name__)

# Names exported by `from <this module> import *`.
__all__ = ['Trainer']


class Trainer(object):
    def __init__(self, cfg, mode='train'):
        self.cfg = cfg
        assert mode.lower() in ['train', 'eval', 'test'], \
                "mode should be 'train', 'eval' or 'test'"
        self.mode = mode.lower()

        # build model
        self.model = create(cfg.architecture)
        if ParallelEnv().nranks > 1:
            self.model = paddle.DataParallel(self.model)

        # build data loader
        self.dataset = cfg['{}Dataset'.format(self.mode.capitalize())]
        # TestDataset build after user set images, skip loader creation here
        if self.mode != 'test':
            self.loader = create('{}Reader'.format(self.mode.capitalize()))(
                self.dataset, cfg.worker_num)

        # build optimizer in train mode
        self.optimizer = None
        if self.mode == 'train':
            steps_per_epoch = len(self.loader)
            self.lr = create('LearningRate')(steps_per_epoch)
            self.optimizer = create('OptimizerBuilder')(self.lr,
                                                        self.model.parameters())

        self.status = {}

        self.start_epoch = 0
        self.end_epoch = cfg.epoch

        self._weights_loaded = False

        # initial default callbacks
        self._init_callbacks()

        # initial default metrics
        self._init_metrics()
        self._reset_metrics()

    def _init_callbacks(self):
        if self.mode == 'train':
            self._callbacks = [LogPrinter(self), Checkpointer(self)]
            self._compose_callback = ComposeCallback(self._callbacks)
        elif self.mode == 'eval':
            self._callbacks = [LogPrinter(self)]
            self._compose_callback = ComposeCallback(self._callbacks)
        else:
            self._callbacks = []
            self._compose_callback = None

    def _init_metrics(self):
        if self.mode == 'eval':
            if self.cfg.metric == 'COCO':
W
wangguanzhong 已提交
100 101
                mask_resolution = self.model.mask_post_process.mask_resolution if getattr(
                    self.model, 'mask_post_process', None) else None
K
Kaipeng Deng 已提交
102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158
                self._metrics = [
                    COCOMetric(
                        anno_file=self.dataset.get_anno(),
                        with_background=self.cfg.with_background,
                        mask_resolution=mask_resolution)
                ]
            elif self.cfg.metric == 'VOC':
                self._metrics = [
                    VOCMetric(
                        anno_file=self.dataset.get_anno(),
                        with_background=self.cfg.with_background,
                        class_num=self.cfg.num_classes,
                        map_type=self.cfg.map_type)
                ]
            else:
                logger.warn("Metric not support for metric type {}".format(
                    self.cfg.metric))
                self._metrics = []
        else:
            self._metrics = []

    def _reset_metrics(self):
        for metric in self._metrics:
            metric.reset()

    def register_callbacks(self, callbacks):
        """
        Append extra callbacks and rebuild the composed callback.

        Args:
            callbacks (iterable): Callback instances; ``None`` entries are
                skipped.

        Raises:
            AssertionError: if a non-None entry is not a Callback instance.
        """
        callbacks = [c for c in callbacks if c is not None]
        for c in callbacks:
            assert isinstance(c, Callback), \
                    "callbacks should be instances of subclass of Callback"
        self._callbacks.extend(callbacks)
        # rebuild so the composed callback is created from the extended list
        self._compose_callback = ComposeCallback(self._callbacks)

    def register_metrics(self, metrics):
        metrics = [m for m in list(metrics) if m is not None]
        for m in metrics:
            assert isinstance(m, Metric), \
                    "metrics shoule be instances of subclass of Metric"
        self._metrics.extend(metrics)

    def load_weights(self, weights, weight_type='pretrain'):
        assert weight_type in ['pretrain', 'resume', 'finetune'], \
                "weight_type can only be 'pretrain', 'resume', 'finetune'"
        if weight_type == 'resume':
            self.start_epoch = load_weight(self.model, weights, self.optimizer)
            logger.debug("Resume weights of epoch {}".format(self.start_epoch))
        else:
            self.start_epoch = 0
            load_pretrain_weight(self.model, weights,
                                 self.cfg.get('load_static_weights', False),
                                 weight_type)
            logger.debug("Load {} weights {} to start training".format(
                weight_type, weights))
        self._weights_loaded = True

    def train(self):
        assert self.mode == 'train', "Model not in 'train' mode"
159
        self.model.train()
K
Kaipeng Deng 已提交
160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203

        # if no given weights loaded, load backbone pretrain weights as default
        if not self._weights_loaded:
            self.load_weights(self.cfg.pretrain_weights)

        self.status.update({
            'epoch_id': self.start_epoch,
            'step_id': 0,
            'steps_per_epoch': len(self.loader)
        })

        self.status['batch_time'] = stats.SmoothedValue(
            self.cfg.log_iter, fmt='{avg:.4f}')
        self.status['data_time'] = stats.SmoothedValue(
            self.cfg.log_iter, fmt='{avg:.4f}')
        self.status['training_staus'] = stats.TrainingStats(self.cfg.log_iter)

        for epoch_id in range(self.start_epoch, self.cfg.epoch):
            self.status['epoch_id'] = epoch_id
            self._compose_callback.on_epoch_begin(self.status)
            self.loader.dataset.set_epoch(epoch_id)
            iter_tic = time.time()
            for step_id, data in enumerate(self.loader):
                self.status['data_time'].update(time.time() - iter_tic)
                self.status['step_id'] = step_id
                self._compose_callback.on_step_begin(self.status)

                # model forward
                outputs = self.model(data)
                loss = outputs['loss']

                # model backward
                loss.backward()
                self.optimizer.step()
                curr_lr = self.optimizer.get_lr()
                self.lr.step()
                self.optimizer.clear_grad()
                self.status['learning_rate'] = curr_lr

                if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0:
                    self.status['training_staus'].update(outputs)

                self.status['batch_time'].update(time.time() - iter_tic)
                self._compose_callback.on_step_end(self.status)
F
Feng Ni 已提交
204
                iter_tic = time.time()
K
Kaipeng Deng 已提交
205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246
            self._compose_callback.on_epoch_end(self.status)

    def evaluate(self):
        """
        Run the model over ``self.loader`` and report all registered metrics.

        Every batch is fed to each metric's ``update``. After the loop the
        number of samples and the elapsed time are stored in ``self.status``
        as 'sample_num' and 'cost_time', each metric is accumulated and
        logged, and the metrics are reset so evaluation can be run again.
        """
        sample_num = 0
        tic = time.time()
        self._compose_callback.on_epoch_begin(self.status)
        # switching to eval mode does not depend on the batch, do it once
        self.model.eval()
        for step_id, data in enumerate(self.loader):
            self.status['step_id'] = step_id
            self._compose_callback.on_step_begin(self.status)
            # forward; gradients are not needed for evaluation, so skip
            # building the autograd graph
            with paddle.no_grad():
                outs = self.model(data)

            # update metrics
            for metric in self._metrics:
                metric.update(data, outs)

            # first dimension of 'im_id' is counted as the batch size
            sample_num += data['im_id'].numpy().shape[0]
            self._compose_callback.on_step_end(self.status)

        self.status['sample_num'] = sample_num
        self.status['cost_time'] = time.time() - tic
        self._compose_callback.on_epoch_end(self.status)

        # accumulate metric to log out
        for metric in self._metrics:
            metric.accumulate()
            metric.log()
        # reset metric states since evaluation may be performed multiple times
        self._reset_metrics()

    def predict(self, images, draw_threshold=0.5, output_dir='output'):
        """
        Run inference on ``images`` and save visualized detections.

        Args:
            images: images to run on, handed to ``self.dataset.set_images``.
            draw_threshold (float): threshold passed to ``visualize_results``.
            output_dir (str): directory the rendered images are written to;
                each output keeps the file name of its source image.
        """
        self.dataset.set_images(images)
        loader = create('TestReader')(self.dataset, 0)

        imid2path = self.dataset.get_imid2path()

        anno_file = self.dataset.get_anno()
        with_background = self.cfg.with_background
        clsid2catid, catid2name = get_categories(self.cfg.metric, anno_file,
                                                 with_background)

        # Run Infer
        # switching to eval mode does not depend on the batch, do it once
        self.model.eval()
        for step_id, data in enumerate(loader):
            self.status['step_id'] = step_id
            # forward; gradients are not needed for inference
            with paddle.no_grad():
                outs = self.model(data)
            for key in ['im_shape', 'scale_factor', 'im_id']:
                outs[key] = data[key]
            # convert every output to numpy (keys are unchanged, so
            # reassigning values while iterating is safe)
            for key, value in outs.items():
                outs[key] = value.numpy()

            # FIXME: for more elegant coding
            if 'mask' in outs and 'bbox' in outs:
                mask_resolution = self.model.mask_post_process.mask_resolution
                from ppdet.py_op.post_process import mask_post_process
                outs['mask'] = mask_post_process(outs, outs['im_shape'],
                                                 outs['scale_factor'],
                                                 mask_resolution)

            batch_res = get_infer_results(outs, clsid2catid)
            bbox_num = outs['bbox_num']
            # results of the whole batch are concatenated; bbox_num[i] is the
            # number of entries that belong to the i-th image
            start = 0
            for i, im_id in enumerate(outs['im_id']):
                image_id = int(im_id)
                image_path = imid2path[image_id]
                image = Image.open(image_path).convert('RGB')
                end = start + bbox_num[i]

                bbox_res = batch_res['bbox'][start:end] \
                        if 'bbox' in batch_res else None
                mask_res = batch_res['mask'][start:end] \
                        if 'mask' in batch_res else None
                segm_res = batch_res['segm'][start:end] \
                        if 'segm' in batch_res else None
                # pass this image's id, not the whole batch array: int() on
                # an array with more than one element raises
                image = visualize_results(image, bbox_res, mask_res, segm_res,
                                          image_id, catid2name,
                                          draw_threshold)

                # save image with detection
                save_name = self._get_save_image_name(output_dir, image_path)
                logger.info("Detection bbox results save in {}".format(
                    save_name))
                image.save(save_name, quality=95)
                start = end

    def _get_save_image_name(self, output_dir, image_path):
        """
        Get save image name from source image path.
        """
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        image_name = os.path.split(image_path)[-1]
        name, ext = os.path.splitext(image_name)
        return os.path.join(output_dir, "{}".format(name)) + ext

    def export(self, output_dir='output_inference'):
        """
        Export the model for inference under ``output_dir/<config name>``.

        Writes 'infer_cfg.yml' through ``_dump_infer_config`` and saves the
        model, converted with ``paddle.jit.to_static``, under the 'model'
        prefix in the same directory.

        Args:
            output_dir (str): root directory for exported models.
        """
        self.model.eval()

        # the export sub-directory is named after the config file, minus
        # its extension
        cfg_file = os.path.split(self.cfg.filename)[-1]
        model_name = os.path.splitext(cfg_file)[0]
        save_dir = os.path.join(output_dir, model_name)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        # image shape comes from TestReader.inputs_def when configured,
        # otherwise 3 channels with unspecified height and width
        reader_cfg = self.cfg['TestReader']
        image_shape = None
        if 'inputs_def' in reader_cfg:
            image_shape = reader_cfg['inputs_def'].get('image_shape', None)
        if image_shape is None:
            image_shape = [3, None, None]

        # Save infer cfg
        infer_cfg_path = os.path.join(save_dir, 'infer_cfg.yml')
        _dump_infer_config(self.cfg, infer_cfg_path, image_shape, self.model)

        # leading None leaves the batch dimension unspecified
        model_inputs = {
            "image": InputSpec(
                shape=[None] + image_shape, name='image'),
            "im_shape": InputSpec(
                shape=[None, 2], name='im_shape'),
            "scale_factor": InputSpec(
                shape=[None, 2], name='scale_factor')
        }
        input_spec = [model_inputs]

        # dy2st and save model
        static_model = paddle.jit.to_static(self.model, input_spec=input_spec)
        # NOTE: dy2st do not pruned program, but jit.save will prune program
        # input spec, prune input spec here and save with pruned input spec
        forward = static_model.forward
        pruned_input_spec = self._prune_input_spec(
            input_spec, forward.main_program, forward.outputs)
        model_prefix = os.path.join(save_dir, 'model')
        paddle.jit.save(
            static_model, model_prefix, input_spec=pruned_input_spec)
        logger.info("Export model and saved in {}".format(save_dir))
340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356

    def _prune_input_spec(self, input_spec, program, targets):
        """
        Keep only the input specs whose variable survives program pruning.

        Args:
            input_spec (list): one-element list holding a dict that maps
                input name to its spec.
            program: static program to inspect; a clone is pruned, the
                passed object is not modified by this method.
            targets: targets handed to ``program._prune``.

        Returns:
            list: one-element list holding the dict of retained specs.
        """
        # try to prune static program to figure out pruned input spec
        # so we perform following operations in static mode
        paddle.enable_static()
        try:
            pruned_input_spec = [{}]
            program = program.clone()
            program = program._prune(targets=targets)
            global_block = program.global_block()
            for name, spec in input_spec[0].items():
                try:
                    global_block.var(name)
                except Exception:
                    # best-effort: a failed lookup is treated as "this input
                    # is not used by the pruned program" and the spec is
                    # dropped (presumably var() raises for pruned variables
                    # -- confirm the exact exception type before narrowing)
                    continue
                pruned_input_spec[0][name] = spec
        finally:
            # always leave static mode, even if cloning or pruning fails
            paddle.disable_static()
        return pruned_input_spec