# meta_arch.py (committed by FDInSky)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from paddle.fluid.dygraph import Layer
from paddle.fluid.dygraph.base import to_variable
from ppdet.core.workspace import register
from ppdet.utils.data_structure import BufferDict

# Public names re-exported by ``from <this module> import *``.
__all__ = ["BaseArch"]


@register
class BaseArch(Layer):
    """Base class for dygraph detection architectures.

    ``forward`` converts a raw batch into framework variables (kept on
    ``self.inputs``), calls ``model_arch`` and then returns either
    ``self.loss()`` or ``self.infer()`` depending on ``mode``. Subclasses
    must override ``model_arch``, ``loss`` and ``infer``.
    """

    def __init__(self):
        super(BaseArch, self).__init__()

    def forward(self, data, input_def, mode):
        """Run the architecture on one batch.

        Args:
            data: iterable of samples; each sample is an iterable of field
                values ordered like ``input_def``.
            input_def: ordered field names, one per element of a sample.
            mode: ``'train'`` to return ``self.loss()`` or ``'infer'`` to
                return ``self.infer()``.

        Returns:
            Whatever the subclass's ``loss()`` or ``infer()`` returns.

        Raises:
            ValueError: if ``mode`` is neither ``'train'`` nor ``'infer'``.
        """
        # Fail fast, before building inputs or running the model. A real
        # exception class is required: ``raise "some string"`` is itself a
        # TypeError in Python 3.
        if mode not in ('train', 'infer'):
            raise ValueError(
                "Now, only support train or infer mode! Got mode={!r}".format(
                    mode))

        self.inputs = self.build_inputs(data, input_def)
        # model_arch/loss/infer receive no mode argument; they can read it
        # from self.inputs instead.
        self.inputs['mode'] = mode
        self.model_arch()

        if mode == 'train':
            return self.loss()
        return self.infer()

    def build_inputs(self, data, input_def):
        """Stack per-sample fields into one batched variable per field name.

        Args:
            data: iterable of samples; the ``j``-th element of each sample is
                the value of field ``input_def[j]``. Fields beyond the length
                of ``input_def`` are ignored (``zip`` semantics).
            input_def: ordered field names.

        Returns:
            dict mapping every name in ``input_def`` to
            ``to_variable(np.stack(values))``, where ``values`` are that
            field's per-sample arrays. Stacking adds a new leading batch axis.

        Raises:
            ValueError: if some field received no values (empty batch, or
                samples shorter than ``input_def``).
        """
        # Materialize once so a one-shot iterable is not exhausted between
        # the collection pass and the conversion pass.
        field_names = list(input_def)
        columns = {name: [] for name in field_names}

        for sample in data:
            for name, value in zip(field_names, sample):
                columns[name].append(np.array(value))

        batched = {}
        for name, values in columns.items():
            if not values:
                raise ValueError(
                    "build_inputs: no values for input field {!r}; the batch "
                    "is empty or samples are shorter than input_def".format(
                        name))
            # np.stack == concatenating each array with a new leading axis.
            batched[name] = to_variable(np.stack(values))
        return batched

    def model_arch(self, mode=None):
        """Build the network graph for the current ``self.inputs``.

        ``forward`` calls this with no arguments; the active mode is available
        as ``self.inputs['mode']``. ``mode`` is kept optional only for
        backward compatibility with callers that pass it explicitly.
        """
        raise NotImplementedError("Should implement model_arch method!")

    def loss(self):
        """Return the training loss; called by ``forward`` in ``'train'`` mode."""
        raise NotImplementedError("Should implement loss method!")

    def infer(self):
        """Return inference outputs; called by ``forward`` in ``'infer'`` mode."""
        raise NotImplementedError("Should implement infer method!")