# meta_arch.py (1.2 KB)
# Blame: FDInSky committed original lines 1-5.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
# Blame: wangxinxin08 committed original lines 6-7.
import paddle
import paddle.nn as nn
# Blame: FDInSky committed original lines 8-14.
from ppdet.core.workspace import register
from ppdet.utils.data_structure import BufferDict

# Names exported by ``from meta_arch import *``.
__all__ = [
    'BaseArch',
]


@register
class BaseArch(nn.Layer):
    """Base class for model architectures.

    ``forward`` converts raw input data into tensors, calls ``model_arch``
    and then dispatches on ``mode``: ``'train'`` returns ``get_loss()`` and
    ``'infer'`` returns ``get_pred()``. Subclasses must implement
    ``model_arch``, ``get_loss`` and ``get_pred``; the base versions raise
    ``NotImplementedError``.
    """

    def __init__(self):
        super(BaseArch, self).__init__()

    def forward(self, data, input_def, mode):
        """Run one forward pass in the given mode.

        Args:
            data: indexable sequence of raw input values; ``data[i]`` is
                paired with ``input_def[i]``.
            input_def: iterable of input names, used as keys of
                ``self.inputs``.
            mode (str): either ``'train'`` or ``'infer'``.

        Returns:
            The result of ``get_loss()`` when ``mode == 'train'``, or of
            ``get_pred()`` when ``mode == 'infer'``.

        Raises:
            ValueError: if ``mode`` is neither ``'train'`` nor ``'infer'``.
        """
        # Fail fast on an unsupported mode, before any tensors are built or
        # the model is run. A plain string cannot be raised in Python 3
        # (it becomes a TypeError), so raise a real exception instead.
        if mode not in ('train', 'infer'):
            raise ValueError(
                "Now, only support train or infer mode! Got mode={!r}".format(
                    mode))

        self.inputs = self.build_inputs(data, input_def)
        # Stored with the inputs, presumably so subclass hooks can read it.
        self.inputs['mode'] = mode

        self.model_arch()

        if mode == 'train':
            return self.get_loss()
        return self.get_pred()

    def build_inputs(self, data, input_def):
        """Convert raw inputs into a dict of tensors keyed by input name.

        Args:
            data: indexable sequence; ``data[i]`` is converted with
                ``paddle.to_tensor`` and stored under ``input_def[i]``.
            input_def: iterable of input names.

        Returns:
            dict mapping each name in ``input_def`` to its tensor.
        """
        # Index into ``data`` (rather than zipping) so that a ``data`` that
        # is shorter than ``input_def`` fails on the lookup instead of
        # silently dropping inputs.
        return {
            name: paddle.to_tensor(data[idx])
            for idx, name in enumerate(input_def)
        }

    def model_arch(self):
        """Run the model on ``self.inputs``; must be overridden."""
        raise NotImplementedError("Should implement model_arch method!")

    def get_loss(self):
        """Return the training output; must be overridden."""
        raise NotImplementedError("Should implement get_loss method!")

    def get_pred(self):
        """Return the inference output; must be overridden."""
        raise NotImplementedError("Should implement get_pred method!")