# meta_arch.py (1.1 KB) -- last committed by qingqing01.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import paddle
import paddle.nn as nn
from ppdet.core.workspace import register

# Public API of this module: only the BaseArch base class is exported.
__all__ = ['BaseArch']


@register
class BaseArch(nn.Layer):
    """Base class for detection architectures.

    ``forward`` stores the batch on ``self.inputs``, calls the
    ``model_arch`` hook, then returns ``get_loss()`` in training mode or
    ``get_pred()`` otherwise. Subclasses must implement ``get_loss`` and
    ``get_pred``; ``model_arch`` is optional (the default does nothing).

    Args:
        data_format (str): layout used by the architecture, ``'NCHW'``
            by default. When ``'NHWC'``, ``forward`` permutes
            ``inputs['image']`` with ``[0, 2, 3, 1]`` before running the
            model (presumably converting an NCHW batch to NHWC -- the
            incoming layout is not checked here).
    """

    def __init__(self, data_format='NCHW'):
        super(BaseArch, self).__init__()
        self.data_format = data_format

    def forward(self, inputs):
        """Run the architecture on one batch.

        Args:
            inputs (dict): batch dict; must contain the key ``'image'``
                when ``self.data_format == 'NHWC'``. The caller's dict is
                not modified.

        Returns:
            The result of ``self.get_loss()`` when ``self.training`` is
            true, otherwise the result of ``self.get_pred()``.
        """
        if self.data_format == 'NHWC':
            # Transpose on a shallow copy: mutating the caller's dict would
            # make a second forward() on the same batch permute the image
            # again and corrupt its layout.
            inputs = dict(inputs)
            inputs['image'] = paddle.transpose(inputs['image'], [0, 2, 3, 1])
        # Stored on the instance -- presumably read by subclasses'
        # model_arch / get_loss / get_pred implementations.
        self.inputs = inputs
        self.model_arch()

        if self.training:
            return self.get_loss()
        return self.get_pred()

    def build_inputs(self, data, input_def):
        """Map positional ``data`` items onto the names in ``input_def``.

        Returns a dict with ``input_def[i] -> data[i]``. Raises IndexError
        if ``data`` has fewer items than ``input_def``; surplus items in
        ``data`` are ignored.
        """
        return {name: data[idx] for idx, name in enumerate(input_def)}

    def model_arch(self):
        """Hook invoked by ``forward`` after ``self.inputs`` is set.

        The base implementation does nothing; subclasses may override it.
        """
        pass

    def get_loss(self):
        """Return the training output; must be implemented by subclasses.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError("Should implement get_loss method!")

    def get_pred(self):
        """Return the inference output; must be implemented by subclasses.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError("Should implement get_pred method!")