meta_arch.py 4.5 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

M
Mark Ma 已提交
5
import numpy as np
Q
qingqing01 已提交
6 7
import paddle
import paddle.nn as nn
M
Mark Ma 已提交
8 9
import typing

Q
qingqing01 已提交
10
from ppdet.core.workspace import register
M
Mark Ma 已提交
11
from ppdet.modeling.post_process import nms
Q
qingqing01 已提交
12 13 14 15 16 17

__all__ = ['BaseArch']


@register
class BaseArch(nn.Layer):
18
    def __init__(self, data_format='NCHW'):
        """Initialise the state shared by every architecture.

        Args:
            data_format (str): Layout tag for the input image. When it is
                'NHWC', ``forward`` transposes the image with perm
                ``[0, 2, 3, 1]`` and ``load_meanstd`` shapes its tensors as
                ``(1, 1, 1, 3)``; any other value selects the
                ``(1, 3, 1, 1)`` shape and no transpose.
        """
        super(BaseArch, self).__init__()
        # Layout tag consulted by forward() and load_meanstd().
        self.data_format = data_format
        # When True, forward() applies `image * self.scale + self.bias`;
        # those tensors are expected to have been set by load_meanstd().
        self.fuse_norm = False
        # Inputs the current pass works on; (re)filled by forward().
        self.inputs = {}

    def load_meanstd(self, cfg_transform):
        """Build ``self.scale`` and ``self.bias`` from a transform config.

        The two tensors are chosen so that ``image * scale + bias`` equals
        ``(image * s - mean) / std``, where ``s`` is ``1/255`` when the
        matched entry has a truthy ``is_scale`` (default ``True``) and ``1``
        otherwise.

        Args:
            cfg_transform (iterable): Transform entries supporting ``in`` and
                key lookup. Only the first entry containing the key
                'NormalizeImage' is used; its value must provide 'mean' and
                'std' and may provide 'is_scale'. When no entry matches, the
                built-in mean/std below are used with a scale of 1.
        """
        # Fallback statistics used when no 'NormalizeImage' entry exists.
        mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        scale = 1.

        # Only the first entry mentioning 'NormalizeImage' is honoured.
        matched = next(
            (op for op in cfg_transform if 'NormalizeImage' in op), None)
        if matched is not None:
            norm_cfg = matched['NormalizeImage']
            mean = np.array(norm_cfg['mean'], dtype=np.float32)
            std = np.array(norm_cfg['std'], dtype=np.float32)
            if norm_cfg.get('is_scale', True):
                scale = 1. / 255.

        # Place the three channel values on the axis matching the layout.
        if self.data_format == 'NHWC':
            channel_shape = (1, 1, 1, 3)
        else:
            channel_shape = (1, 3, 1, 1)
        self.scale = paddle.to_tensor(scale / std).reshape(channel_shape)
        self.bias = paddle.to_tensor(-mean / std).reshape(channel_shape)
Q
qingqing01 已提交
42

43
    def forward(self, inputs):
        """Run the architecture on ``inputs``.

        Args:
            inputs (dict | Sequence[dict]): Either one mapping of input
                fields, or a sequence of such mappings (one per scale) for
                multi-scale testing. A mapping must contain 'image' when
                ``data_format`` is 'NHWC' or ``fuse_norm`` is set, and also
                'im_shape' and 'scale_factor' when ``fuse_norm`` is set.
                With 'NHWC' the 'image' entry of each mapping is replaced in
                place by its transposed version.

        Returns:
            The value of ``get_loss()`` in training mode. Otherwise the
            value of ``get_pred()`` for a single input, or the result of
            ``merge_multi_scale_predictions`` when more than one prediction
            was produced.
        """
        # multi-scale input: detect it first, so that a sequence is never
        # indexed like a mapping in the layout / normalisation steps below.
        multi_scale = isinstance(inputs, typing.Sequence)
        inputs_list = list(inputs) if multi_scale else [inputs]

        def stage(sample):
            # Make `sample` the data that get_loss()/get_pred() will read.
            if self.fuse_norm:
                self.inputs['image'] = (
                    sample['image'] * self.scale + self.bias)
                self.inputs['im_shape'] = sample['im_shape']
                self.inputs['scale_factor'] = sample['scale_factor']
            else:
                self.inputs = sample

        if self.data_format == 'NHWC':
            # Transpose every scale, not only a single mapping.
            for sample in inputs_list:
                sample['image'] = paddle.transpose(sample['image'],
                                                   [0, 2, 3, 1])

        # A sequence cannot be normalised as a whole; with fuse_norm each
        # scale is staged right before its get_pred() call further down.
        if not (multi_scale and self.fuse_norm):
            stage(inputs)

        self.model_arch()

        if self.training:
            return self.get_loss()

        outs = []
        for sample in inputs_list:
            stage(sample)
            outs.append(self.get_pred())

        # multi-scale test
        if len(outs) > 1:
            return self.merge_multi_scale_predictions(outs)
        return outs[0]

    def merge_multi_scale_predictions(self, outs):
        """Merge per-scale detections into one result via per-class NMS.

        Args:
            outs (list[dict]): One prediction dict per scale. Each must hold
                a 'bbox' tensor whose column 0 is compared against the class
                index; the remaining columns are passed to ``nms``.
                NOTE(review): rows are presumably
                [class_id, score, x1, y1, x2, y2] -- confirm against ``nms``.

        Returns:
            dict: ``{'bbox': Tensor, 'bbox_num': Tensor}``. 'bbox' holds the
            kept rows ordered by ascending column 1 of the merged array, and
            'bbox_num' holds the number of kept rows. When nothing survives,
            'bbox' has shape ``(0, 6)`` and 'bbox_num' is ``[0]``.

        Raises:
            Exception: If the class name of ``self`` is not one of
                'CascadeRCNN', 'FasterRCNN' or 'MaskRCNN'.
        """
        supported = ('CascadeRCNN', 'FasterRCNN', 'MaskRCNN')
        if self.__class__.__name__ not in supported:
            raise Exception(
                "Multi scale test only supports CascadeRCNN, FasterRCNN and MaskRCNN for now"
            )

        num_classes = self.bbox_head.num_classes
        keep_top_k = self.bbox_post_process.nms.keep_top_k
        nms_threshold = self.bbox_post_process.nms.nms_threshold

        all_scale_outs = paddle.concat([o['bbox'] for o in outs]).numpy()

        final_boxes = []
        for c in range(num_classes):
            idxs = all_scale_outs[:, 0] == c
            if not idxs.any():
                continue
            r = nms(all_scale_outs[idxs, 1:], nms_threshold)
            # Put the class id back in front of the rows kept by nms.
            final_boxes.append(
                np.concatenate([np.full((r.shape[0], 1), c), r], 1))

        if final_boxes:
            out = np.concatenate(final_boxes)
            # Stable ascending sort on column 1, same order that sorting
            # the rows one by one with a key on e[1] would give.
            out = out[np.argsort(out[:, 1], kind='stable')]
            # Keep the k highest-keyed rows. A non-positive k keeps them
            # all; slicing with a negative k would drop the best-ranked
            # tail's complement instead (e.g. -1 would discard one row).
            # NOTE(review): "-1 means keep all" follows the Paddle
            # multiclass_nms convention -- confirm for this nms config.
            if keep_top_k > 0:
                out = out[-keep_top_k:]
        else:
            # np.concatenate([]) raises ValueError, so build the empty
            # result explicitly when no class produced any box.
            out = np.zeros((0, 6), dtype=all_scale_outs.dtype)

        out = {
            'bbox': paddle.to_tensor(out),
            'bbox_num': paddle.to_tensor(np.array([out.shape[0], ]))
        }

        return out

    def build_inputs(self, data, input_def):
        """Pair each name in ``input_def`` with the item of ``data`` at the
        same position.

        Args:
            data: Indexable container; ``data[i]`` is read for the i-th name.
            input_def (iterable): Names to use as keys, in positional order.

        Returns:
            dict: Mapping from each name to ``data[i]``.
        """
        return {name: data[pos] for pos, name in enumerate(input_def)}

124 125
    def model_arch(self):
        """Hook called by ``forward`` after the inputs are staged and before
        ``get_loss``/``get_pred``; does nothing in this base class."""
Q
qingqing01 已提交
126 127 128 129 130 131

    def get_loss(self):
        """Return the training output; called by ``forward`` in training mode.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        message = "Should implement get_loss method!"
        raise NotImplementedError(message)

    def get_pred(self):
        """Return the prediction for the staged inputs; called by ``forward``
        once per input when not in training mode.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        message = "Should implement get_pred method!"
        raise NotImplementedError(message)