import abc

import paddle.fluid as fluid

from paddlerec.core.utils import envs


class Model(object):
    """Base class for models: holds the state a trainer reads back through
    the getter methods and builds the optimizer from the hyper_parameters
    configuration.

    Subclasses must provide ``train_net`` and ``infer_net``; the getters
    below simply return the attributes initialised in ``__init__``, which
    subclasses are expected to fill in.
    """
    # NOTE(review): ``__metaclass__`` is only honoured by Python 2. On
    # Python 3 it is an ordinary class attribute, so the @abc.abstractmethod
    # markers below are not enforced when a subclass is instantiated.
    __metaclass__ = abc.ABCMeta

    def __init__(self, config):
        """Initialise empty model state.

        Args:
            config: accepted for the caller/subclass interface; this base
                constructor does not read it.
        """
        self._cost = None  # returned by get_cost_op()
        self._metrics = {}  # returned by get_metrics()
        self._data_var = []  # returned by get_inputs()
        self._infer_data_var = []  # returned by get_infer_inputs()
        self._infer_results = {}  # returned by get_infer_results()
        self._data_loader = None
        self._infer_data_loader = None
        self._fetch_interval = 20  # default returned by get_fetch_period()
        # Namespace passed to every envs.get_global_env lookup in this class.
        self._namespace = "train.model"
        self._platform = envs.get_platform()

    def get_inputs(self):
        """Return the training input variables (``self._data_var``)."""
        return self._data_var

    def get_infer_inputs(self):
        """Return the inference input variables (``self._infer_data_var``)."""
        return self._infer_data_var

    def get_infer_results(self):
        """Return the inference results (``self._infer_results``)."""
        return self._infer_results

    def get_cost_op(self):
        """Return the loss set by the subclass (``None`` until set)."""
        return self._cost

    def get_metrics(self):
        """Return the metrics set by the subclass (empty dict by default)."""
        return self._metrics

    def custom_preprocess(self):
        """Hook run after exe.run(startup_program) and before run().

        The base implementation does nothing.
        """
        pass

    def get_fetch_period(self):
        """Return the fetch interval (``self._fetch_interval``, 20 by default)."""
        return self._fetch_interval

    def _build_optimizer(self, name, lr):
        """Create a fluid optimizer from its configured name.

        Args:
            name: optimizer name, matched case-insensitively against
                SGD, ADAM and ADAGRAD.
            lr: learning rate forwarded to the fluid optimizer constructor.

        Returns:
            The constructed ``fluid.optimizer`` instance.

        Raises:
            ValueError: if ``name`` is None or not a supported optimizer.
        """
        if name is None:
            # optimizer() passes None as the get_global_env default, so a
            # missing config key presumably arrives here as None; fail with
            # the same clear error instead of AttributeError on None.upper().
            raise ValueError(
                "configured optimizer can only supported SGD/Adam/Adagrad")
        name = name.upper()

        if name == "SGD":
            # Only SGD applies L2 weight decay, with its strength read from
            # hyper_parameters.reg (default 0.0001).
            reg = envs.get_global_env(
                "hyper_parameters.reg", 0.0001, self._namespace)
            return fluid.optimizer.SGD(
                lr, regularization=fluid.regularizer.L2DecayRegularizer(reg))
        if name == "ADAM":
            return fluid.optimizer.Adam(lr, lazy_mode=True)
        if name == "ADAGRAD":
            return fluid.optimizer.Adagrad(lr)

        # Any other name is unsupported; no config is read before failing.
        raise ValueError(
            "configured optimizer can only supported SGD/Adam/Adagrad")

    def optimizer(self):
        """Build the optimizer described by hyper_parameters.optimizer and
        hyper_parameters.learning_rate in this model's namespace.

        Raises:
            ValueError: propagated from _build_optimizer when the optimizer
                name is missing or unsupported.
        """
        learning_rate = envs.get_global_env(
            "hyper_parameters.learning_rate", None, self._namespace)
        optimizer = envs.get_global_env(
            "hyper_parameters.optimizer", None, self._namespace)
        print(">>>>>>>>>>>.learnig rate: %s" % learning_rate)
        return self._build_optimizer(optimizer, learning_rate)

    @abc.abstractmethod
    def train_net(self):
        """Build the training network; must be implemented by subclasses."""
        pass

    @abc.abstractmethod
    def infer_net(self):
        """Build the inference network; must be implemented by subclasses."""
        pass