meta_arch.py 1.4 KB
Newer Older
F
FDInSky 已提交
1 2 3 4 5
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
W
wangxinxin08 已提交
6 7
import paddle
import paddle.nn as nn
F
FDInSky 已提交
8 9 10 11 12 13
from ppdet.core.workspace import register

__all__ = ['BaseArch']


@register
class BaseArch(nn.Layer):
    """Base class for model architectures.

    Subclasses must implement ``model_arch``, ``get_loss`` and ``get_pred``.
    ``forward`` assembles ``self.inputs``, calls ``model_arch()`` and then
    dispatches to ``get_loss()`` ('train' mode) or ``get_pred()``
    ('infer' mode).
    """

    def __init__(self):
        super(BaseArch, self).__init__()

    def forward(self,
                input_tensor=None,
                data=None,
                input_def=None,
                mode='infer'):
        """Run the architecture in the requested mode.

        Args:
            input_tensor: a ready-made inputs object supporting item
                assignment. When given, it is used as ``self.inputs``
                directly (it is NOT copied), so the ``'mode'`` key is
                written into the caller's object.
            data: positional input values; used together with
                ``input_def`` only when ``input_tensor`` is None.
            input_def: input names, paired positionally with ``data``
                (see ``build_inputs``).
            mode (str): 'train' or 'infer'.

        Returns:
            The result of ``get_loss()`` in 'train' mode, or of
            ``get_pred()`` in 'infer' mode.

        Raises:
            AssertionError: if ``input_tensor`` is None and ``data`` or
                ``input_def`` is None (this check is skipped under
                ``python -O``).
            ValueError: if ``mode`` is neither 'train' nor 'infer'.
        """
        if input_tensor is None:
            assert data is not None and input_def is not None, \
                "data and input_def are required when input_tensor is None"
            self.inputs = self.build_inputs(data, input_def)
        else:
            self.inputs = input_tensor

        # Expose the mode to model_arch() / get_loss() / get_pred().
        self.inputs['mode'] = mode
        self.model_arch()

        if mode == 'train':
            return self.get_loss()
        if mode == 'infer':
            return self.get_pred()
        # Raising a bare string is a TypeError in Python 3; raise a real
        # exception so callers see the intended message.
        raise ValueError(
            "Now, only support train and infer mode! Got mode={!r}".format(
                mode))

    def build_inputs(self, data, input_def):
        """Pair each name in ``input_def`` with ``data`` at the same index.

        Args:
            data: indexable sequence of input values.
            input_def: iterable of input names.

        Returns:
            dict: ``{input_def[i]: data[i]}``. Elements of ``data`` beyond
            ``len(input_def)`` are ignored. Indexing ``data`` fails when it
            has fewer elements than ``input_def``.
        """
        inputs = {}
        for i, k in enumerate(input_def):
            inputs[k] = data[i]
        return inputs

    def model_arch(self):
        """Run the architecture's forward computation; subclasses override."""
        raise NotImplementedError("Should implement model_arch method!")

    def get_loss(self, ):
        """Return the training output; subclasses override."""
        raise NotImplementedError("Should implement get_loss method!")

    def get_pred(self, ):
        """Return the inference output; subclasses override."""
        raise NotImplementedError("Should implement get_pred method!")