# meta_arch.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn
from ppdet.core.workspace import register

# Public API of this module.
__all__ = ['BaseArch']


@register
class BaseArch(nn.Layer):
    """Base class for detection architectures.

    Subclasses provide ``get_loss`` and ``get_pred`` (and may override the
    ``model_arch`` hook). ``forward`` prepares ``self.inputs``, calls
    ``model_arch``, then returns ``get_loss()`` in training mode and
    ``get_pred()`` otherwise.

    Args:
        data_format (str): layout of the image tensor used inside the model,
            ``'NCHW'`` (default) or ``'NHWC'``.
    """

    def __init__(self, data_format='NCHW'):
        super(BaseArch, self).__init__()
        self.data_format = data_format
        # Inputs read by model_arch/get_loss/get_pred; (re)populated in forward.
        self.inputs = {}
        # When True, forward normalizes the image itself using the statistics
        # set by load_meanstd, which must therefore be called beforehand.
        self.fuse_norm = False

    def load_meanstd(self, cfg_transform):
        """Set the normalization statistics used when ``fuse_norm`` is on.

        Defaults to mean ``[0.485, 0.456, 0.406]``, std
        ``[0.229, 0.224, 0.225]`` and no pixel scaling (``scale = 1.``).
        The first entry of ``cfg_transform`` that has a ``'NormalizeImage'``
        key overrides ``mean``/``std``. Unless that entry's ``is_scale`` is
        False, it also sets ``scale`` to ``1/255``.

        Args:
            cfg_transform (iterable of dict): transform config entries;
                presumably the reader's sample transforms -- confirm against
                the caller.
        """
        # Shape broadcastable against a channels-first (NCHW) image batch.
        stat_shape = (1, 3, 1, 1)
        self.scale = 1.
        self.mean = paddle.to_tensor([0.485, 0.456, 0.406]).reshape(stat_shape)
        self.std = paddle.to_tensor([0.229, 0.224, 0.225]).reshape(stat_shape)
        for item in cfg_transform:
            if 'NormalizeImage' in item:
                norm_cfg = item['NormalizeImage']
                self.mean = paddle.to_tensor(norm_cfg['mean']).reshape(
                    stat_shape)
                self.std = paddle.to_tensor(norm_cfg['std']).reshape(
                    stat_shape)
                if norm_cfg.get('is_scale', True):
                    self.scale = 1. / 255.
                break
        if self.data_format == 'NHWC':
            # Channels-last: move the channel axis to the end. The target
            # shape must be passed as a single sequence -- paddle's reshape
            # signature is reshape(x, shape, name=None), so varargs dims
            # (reshape(1, 1, 1, 3)) raise a TypeError.
            self.mean = self.mean.reshape((1, 1, 1, 3))
            self.std = self.std.reshape((1, 1, 1, 3))

    def forward(self, inputs):
        """Run the architecture on one batch.

        Args:
            inputs (dict): batch data. Must contain ``'image'``. When
                ``fuse_norm`` is True it must also contain ``'im_shape'`` and
                ``'scale_factor'``.

        Returns:
            ``self.get_loss()`` when ``self.training`` is True, otherwise
            ``self.get_pred()``.
        """
        if self.data_format == 'NHWC':
            # Permutation [0, 2, 3, 1] moves axis 1 last, so the incoming image
            # is assumed channels-first. NOTE: this overwrites
            # inputs['image'] in the caller's dict.
            image = inputs['image']
            inputs['image'] = paddle.transpose(image, [0, 2, 3, 1])

        if self.fuse_norm:
            # In-graph normalization: (image * scale - mean) / std.
            # NOTE(review): only image/im_shape/scale_factor are copied into
            # self.inputs; any other keys of ``inputs`` are not forwarded.
            image = inputs['image']
            self.inputs['image'] = (image * self.scale - self.mean) / self.std
            self.inputs['im_shape'] = inputs['im_shape']
            self.inputs['scale_factor'] = inputs['scale_factor']
        else:
            self.inputs = inputs

        self.model_arch()

        if self.training:
            out = self.get_loss()
        else:
            out = self.get_pred()
        return out

    def build_inputs(self, data, input_def):
        """Map positional ``data`` to a dict keyed by the names in ``input_def``.

        Args:
            data (sequence): values indexed by position; must have at least
                ``len(input_def)`` elements (an ``IndexError`` is raised
                otherwise).
            input_def (iterable of str): key names, in the same order as
                ``data``.

        Returns:
            dict: ``{input_def[i]: data[i]}`` for each position ``i``.
        """
        return {key: data[i] for i, key in enumerate(input_def)}

    def model_arch(self):
        """Hook run by ``forward`` after ``self.inputs`` is set; no-op here."""
        pass

    def get_loss(self):
        """Return the training loss; subclasses must implement this."""
        raise NotImplementedError("Should implement get_loss method!")

    def get_pred(self):
        """Return the inference predictions; subclasses must implement this."""
        raise NotImplementedError("Should implement get_pred method!")