meta_arch.py 4.1 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

M
Mark Ma 已提交
5
import numpy as np
Q
qingqing01 已提交
6 7
import paddle
import paddle.nn as nn
M
Mark Ma 已提交
8 9
import typing

Q
qingqing01 已提交
10
from ppdet.core.workspace import register
M
Mark Ma 已提交
11
from ppdet.modeling.post_process import nms
Q
qingqing01 已提交
12 13 14 15 16 17

__all__ = ['BaseArch']


@register
class BaseArch(nn.Layer):
    """Common base for detection architectures.

    Implements the parts of ``forward`` shared by every architecture:
    optional NHWC layout conversion, optional fused input normalization
    (``fuse_norm``), training-vs-eval dispatch and multi-scale testing.
    Subclasses implement ``get_loss`` and ``get_pred`` and may override the
    ``model_arch`` hook.
    """

    def __init__(self, data_format='NCHW'):
        super(BaseArch, self).__init__()
        # 'NCHW' (default) or 'NHWC'; in NHWC mode forward() transposes the
        # incoming images with permutation [0, 2, 3, 1].
        self.data_format = data_format
        # Feed dict read by model_arch / get_loss / get_pred.
        self.inputs = {}
        # When True, forward() normalizes images itself with the constants
        # loaded by load_meanstd().
        self.fuse_norm = False

    def load_meanstd(self, cfg_transform):
        """Load the normalization constants used when ``fuse_norm`` is on.

        Args:
            cfg_transform (list[dict]): transform configs. The first entry
                that has a ``'NormalizeImage'`` key supplies ``mean``, ``std``
                and ``is_scale`` (default True, which sets ``scale`` to
                1/255). Without such an entry the hard-coded mean/std below
                are kept and ``scale`` stays 1.
        """
        self.scale = 1.
        self.mean = paddle.to_tensor([0.485, 0.456, 0.406]).reshape(
            (1, 3, 1, 1))
        self.std = paddle.to_tensor([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))
        for item in cfg_transform:
            if 'NormalizeImage' in item:
                norm_cfg = item['NormalizeImage']
                self.mean = paddle.to_tensor(norm_cfg['mean']).reshape(
                    (1, 3, 1, 1))
                self.std = paddle.to_tensor(norm_cfg['std']).reshape(
                    (1, 3, 1, 1))
                if norm_cfg.get('is_scale', True):
                    self.scale = 1. / 255.
                break
        if self.data_format == 'NHWC':
            # Channels-last images need channels-last broadcasting. The shape
            # is passed as one tuple, like the reshapes above (reshape takes
            # the target shape as a single argument).
            self.mean = self.mean.reshape((1, 1, 1, 3))
            self.std = self.std.reshape((1, 1, 1, 3))

    def forward(self, inputs):
        """Run the architecture on a feed dict or a sequence of them.

        Args:
            inputs (dict | Sequence[dict]): one feed dict with at least
                ``'image'`` (plus ``'im_shape'`` and ``'scale_factor'`` when
                ``fuse_norm`` is on), or a sequence of such dicts, one per
                test scale.

        Returns:
            ``get_loss()`` output in training mode. Otherwise the
            ``get_pred()`` output for a single input, or the result of
            ``merge_multi_scale_predictions`` when there are several.
        """
        multi_scale = isinstance(inputs, typing.Sequence)
        inputs_list = list(inputs) if multi_scale else [inputs]

        if self.data_format == 'NHWC':
            # Transpose every image (in place in its dict, as before) so each
            # scale of a multi-scale list gets the same layout.
            for inp in inputs_list:
                inp['image'] = paddle.transpose(inp['image'], [0, 2, 3, 1])

        if multi_scale:
            # A list cannot be normalized as a whole; each scale is prepared
            # individually right before its own prediction below.
            self.inputs = inputs
        else:
            self._prepare_inputs(inputs)

        self.model_arch()

        if self.training:
            return self.get_loss()

        outs = []
        for inp in inputs_list:
            if multi_scale:
                # Re-prepare per scale so fused normalization is applied to
                # each one. Single-input eval keeps the inputs prepared above
                # instead of replacing them with the raw, un-normalized dict.
                self._prepare_inputs(inp)
            outs.append(self.get_pred())

        # multi-scale test
        if len(outs) > 1:
            return self.merge_multi_scale_predictions(outs)
        return outs[0]

    def _prepare_inputs(self, inp):
        """Set ``self.inputs`` from one feed dict ``inp``.

        With ``fuse_norm`` on, a fresh dict holding the normalized image
        (``(image * scale - mean) / std``), ``im_shape`` and ``scale_factor``
        is built. A fresh dict is used so a caller's dict is never written
        into. Otherwise ``inp`` is used as-is.
        """
        if self.fuse_norm:
            self.inputs = {
                'image': (inp['image'] * self.scale - self.mean) / self.std,
                'im_shape': inp['im_shape'],
                'scale_factor': inp['scale_factor'],
            }
        else:
            self.inputs = inp

    def merge_multi_scale_predictions(self, outs):
        """Merge per-scale detections with class-wise NMS.

        Args:
            outs (list[dict]): one ``get_pred()`` result per test scale, each
                with a ``'bbox'`` tensor of 6-column rows. Column 0 is
                compared against class ids and column 1 is the sort key.
                Presumably the layout is [class, score, x1, y1, x2, y2];
                verify against the bbox post-process.

        Returns:
            dict: ``'bbox'`` (rows kept after NMS, at most ``keep_top_k`` when
            it is positive, in ascending column-1 order, in the input dtype)
            and ``'bbox_num'`` (1-element tensor with the row count).

        Raises:
            NotImplementedError: the architecture is not CascadeRCNN,
                FasterRCNN or MaskRCNN.
        """
        if self.__class__.__name__ not in ('CascadeRCNN', 'FasterRCNN',
                                           'MaskRCNN'):
            raise NotImplementedError(
                "Multi scale test only supports CascadeRCNN, FasterRCNN and MaskRCNN for now")

        num_classes = self.bbox_head.num_classes
        keep_top_k = self.bbox_post_process.nms.keep_top_k
        nms_threshold = self.bbox_post_process.nms.nms_threshold

        all_scale_outs = paddle.concat([o['bbox'] for o in outs]).numpy()
        dtype = all_scale_outs.dtype

        final_boxes = []
        for c in range(num_classes):
            idxs = all_scale_outs[:, 0] == c
            if not np.any(idxs):
                continue
            r = nms(all_scale_outs[idxs, 1:], nms_threshold)
            # Build the label column in the boxes' dtype so an int label
            # does not silently up-cast the merged result to float64.
            labels = np.full((r.shape[0], 1), c, dtype=dtype)
            final_boxes.append(np.concatenate([labels, r], 1))

        if final_boxes:
            merged = np.concatenate(final_boxes).reshape((-1, 6))
        else:
            # np.concatenate([]) would raise; report "no detections" instead.
            merged = np.zeros((0, 6), dtype=dtype)

        # Stable ascending sort on column 1, then keep the last keep_top_k
        # rows (same selection and order as a stable sorted(...)[-k:]).
        # Non-positive keep_top_k keeps every row.
        order = np.argsort(merged[:, 1], kind='stable')
        if keep_top_k > 0:
            order = order[-keep_top_k:]
        merged = merged[order]

        return {
            'bbox': paddle.to_tensor(merged),
            'bbox_num': paddle.to_tensor(np.array([merged.shape[0], ]))
        }

    def build_inputs(self, data, input_def):
        """Name positional ``data`` by ``input_def``: ``data[i]`` -> key ``input_def[i]``."""
        return {k: data[i] for i, k in enumerate(input_def)}

    def model_arch(self, ):
        """Hook run once per forward() before loss/prediction; no-op here."""
        pass

    def get_loss(self, ):
        """Return training losses; must be implemented by subclasses."""
        raise NotImplementedError("Should implement get_loss method!")

    def get_pred(self, ):
        """Return predictions for ``self.inputs``; must be implemented by subclasses."""
        raise NotImplementedError("Should implement get_pred method!")