# meta_arch.py - committed by FDInSky
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from paddle.fluid.dygraph import Layer
from paddle.fluid.dygraph.base import to_variable
from ppdet.core.workspace import register
from ppdet.utils.data_structure import BufferDict

# Public API of this module: a star-import exposes only BaseArch.
__all__ = ['BaseArch']


@register
class BaseArch(Layer):
    """Base class for dygraph detection architectures.

    ``forward`` loads the constructor keyword arguments and the batch
    tensors into a ``BufferDict`` stored on ``self.gbd``, calls
    ``model_arch`` and then returns ``loss()`` in ``'train'`` mode or
    ``infer()`` in ``'infer'`` mode. Subclasses implement those three hooks.
    """

    def __init__(self, *args, **kwargs):
        super(BaseArch, self).__init__()
        # Stored verbatim; ``forward`` copies ``kwargs`` into ``self.gbd``.
        self.args = args
        self.kwargs = kwargs

    def forward(self, inputs, inputs_keys):
        """Run the architecture on one batch.

        Args:
            inputs: iterable of samples; for every sample ``x``, ``x[i]`` is
                the value belonging to ``inputs_keys[i]``.
            inputs_keys: names under which each stacked field is stored in
                ``self.gbd``.

        Returns:
            ``self.loss()`` when the ``mode`` keyword argument is
            ``'train'``, ``self.infer()`` when it is ``'infer'``.

        Raises:
            ValueError: if ``mode`` is missing/None, or is neither
                ``'train'`` nor ``'infer'``.
        """
        # Validate up front (not via ``assert``, which ``-O`` strips) so a bad
        # config fails before any inputs are built or the network is run.
        mode = self.kwargs.get('mode')
        if mode is None:
            raise ValueError(
                "Please specify mode train or infer in config file!")
        if mode not in ('train', 'infer'):
            raise ValueError("Now, only support train or infer mode!")

        # Normalise the config before populating the buffer: a missing
        # 'open_debug' no longer raises KeyError, and the buffer is filled in
        # a single update instead of re-assigning a key it already holds.
        config = dict(self.kwargs)
        if config.get('open_debug') is None:
            config['open_debug'] = False

        self.gbd = BufferDict()
        self.gbd.update(config)

        self.build_inputs(inputs, inputs_keys)
        self.model_arch()

        # NOTE(review): called unconditionally; presumably BufferDict.debug
        # consults 'open_debug' itself - confirm in ppdet.utils.data_structure.
        self.gbd.debug()

        if mode == 'train':
            return self.loss()
        return self.infer()

    def build_inputs(self, inputs, inputs_keys):
        """Stack each input field across the batch and store it in ``self.gbd``.

        For the ``i``-th key, the ``i``-th element of every sample in
        ``inputs`` is gathered into one ``np.array`` and converted with
        ``to_variable`` before being stored under that key.
        """
        for i, key in enumerate(inputs_keys):
            batch_field = np.array([sample[i] for sample in inputs])
            self.gbd.set(key, to_variable(batch_field))

    def model_arch(self):
        """Hook called by ``forward`` after ``build_inputs``; must be overridden."""
        raise NotImplementedError("Should implement model_arch method!")

    def loss(self):
        """Hook whose result ``forward`` returns in 'train' mode; must be overridden."""
        raise NotImplementedError("Should implement loss method!")

    def infer(self):
        """Hook whose result ``forward`` returns in 'infer' mode; must be overridden."""
        raise NotImplementedError("Should implement infer method!")