model.py 4.7 KB
Newer Older
T
tangwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

T
tangwei 已提交
15
import abc
T
tangwei 已提交
16 17 18

import paddle.fluid as fluid

19
from paddlerec.core.utils import envs
T
tangwei 已提交
20

T
tangwei 已提交
21

T
tangwei 已提交
22
class Model(object):
    """Base class for PaddleRec models.

    Holds the state shared by concrete models (input variables, cost,
    metrics, inference inputs/results and data loaders) and provides helpers
    that build input slots and the optimizer from the global configuration
    read through ``envs``. Subclasses implement ``train_net`` and
    ``infer_net``.
    """
    # NOTE: ``__metaclass__`` is only honoured by Python 2. Python 3 ignores
    # the attribute, so the ``@abc.abstractmethod`` markers below are not
    # enforced there.
    __metaclass__ = abc.ABCMeta

    # Optimizer names accepted by _build_optimizer (compared upper-cased).
    _SUPPORTED_OPTIMIZERS = ("SGD", "ADAM", "ADAGRAD")

    def __init__(self, config):
        """Initialize empty model state, then call ``_init_hyper_parameters``.

        Args:
            config: accepted for interface compatibility; this base
                implementation does not read it.
        """
        self._cost = None
        self._metrics = {}
        self._data_var = []
        self._infer_data_var = []
        self._infer_results = {}
        self._data_loader = None
        self._infer_data_loader = None
        # Value returned by get_fetch_period().
        self._fetch_interval = 20
        # Namespace used for the hyper_parameters.* lookups in this class.
        self._namespace = "train.model"
        self._platform = envs.get_platform()
        self._init_hyper_parameters()

    def _init_hyper_parameters(self):
        """Hook for subclasses to read hyper-parameters; no-op by default."""
        pass

    @staticmethod
    def _split_slots(slots):
        """Split a whitespace-separated slot string into a list of specs.

        ``None`` and blank strings yield an empty list, and runs of
        whitespace count as one separator. Empty slot names are therefore
        never produced.
        """
        if slots is None:
            return []
        return slots.split()

    @staticmethod
    def _parse_dense_slot(spec):
        """Parse a dense slot spec of the form ``name:[d0,d1,...]``.

        Returns:
            tuple: ``(name, shape)`` where ``shape`` is a list of ints.

        Raises:
            ValueError: if ``spec`` has no ``:`` separator or a dimension is
                not an integer.
        """
        name, sep, dims = spec.partition(":")
        if not sep:
            raise ValueError(
                "dense slot %r must look like name:[d0,d1,...]" % spec)
        return name, [int(d) for d in dims.strip("[]").split(",")]

    def _init_slots(self):
        """Create fluid input variables for the slots configured for reading.

        Reads ``sparse_slots`` and ``dense_slots`` (whitespace-separated
        strings) from the ``train.reader`` namespace. When at least one of
        them is set, the following variables are created and appended to
        ``self._data_var``, dense slots first:

        - one float32 variable per dense slot (shape taken from the spec),
          also collected in ``self._dense_data_var``;
        - one int64 variable per sparse slot (``shape=[1]``,
          ``lod_level=1``), also collected in ``self._sparse_data_var``.

        Either setting may be left unset; it then contributes no variables.
        If ``dataset_class`` is ``"DataLoader"``, a data loader over
        ``self._data_var`` is created as well.
        """
        sparse_conf = envs.get_global_env("sparse_slots", None,
                                          "train.reader")
        dense_conf = envs.get_global_env("dense_slots", None, "train.reader")

        if sparse_conf is not None or dense_conf is not None:
            # Parse every dense spec before creating any variable, so a
            # malformed spec fails before the program is partially built.
            dense_specs = [
                self._parse_dense_slot(spec)
                for spec in self._split_slots(dense_conf)
            ]
            self._dense_data_var = []
            for name, shape in dense_specs:
                var = fluid.layers.data(
                    name=name, shape=shape, dtype="float32")
                self._data_var.append(var)
                self._dense_data_var.append(var)

            self._sparse_data_var = []
            for name in self._split_slots(sparse_conf):
                var = fluid.layers.data(
                    name=name, shape=[1], lod_level=1, dtype="int64")
                self._data_var.append(var)
                self._sparse_data_var.append(var)

        dataset_class = envs.get_global_env("dataset_class", None,
                                            "train.reader")
        if dataset_class == "DataLoader":
            self._init_dataloader()

    def _init_dataloader(self):
        """Create a non-iterable DataLoader that feeds ``self._data_var``."""
        self._data_loader = fluid.io.DataLoader.from_generator(
            feed_list=self._data_var,
            capacity=64,
            use_double_buffer=False,
            iterable=False)

    def get_inputs(self):
        """Return the list of training input variables."""
        return self._data_var

    def get_infer_inputs(self):
        """Return the list of inference input variables."""
        return self._infer_data_var

    def get_infer_results(self):
        """Return the dict of inference result variables."""
        return self._infer_results

    def get_avg_cost(self):
        """Return the cost set by the subclass (``None`` until set)."""
        return self._cost

    def get_metrics(self):
        """Return the dict of metrics set by the subclass."""
        return self._metrics

    def get_fetch_period(self):
        """Return the fetch interval (20 unless a subclass changes it)."""
        return self._fetch_interval

    def _build_optimizer(self, name, lr):
        """Build a fluid optimizer by (case-insensitive) name.

        Args:
            name: one of ``SGD``, ``Adam`` or ``Adagrad`` in any letter case.
            lr: learning rate passed to the optimizer constructor.

        Returns:
            The constructed ``fluid.optimizer`` instance. SGD is built with
            an L2 decay regularizer whose coefficient is read from
            ``hyper_parameters.reg`` (default 0.0001). Adam is built with
            ``lazy_mode=True``.

        Raises:
            ValueError: if ``name`` is unset (``None``) or not supported.
        """
        upper = None if name is None else name.upper()
        if upper not in self._SUPPORTED_OPTIMIZERS:
            raise ValueError(
                "configured optimizer can only supported SGD/Adam/Adagrad")

        if upper == "SGD":
            reg = envs.get_global_env("hyper_parameters.reg", 0.0001,
                                      self._namespace)
            return fluid.optimizer.SGD(
                lr, regularization=fluid.regularizer.L2DecayRegularizer(reg))
        if upper == "ADAM":
            return fluid.optimizer.Adam(lr, lazy_mode=True)
        return fluid.optimizer.Adagrad(lr)

    def optimizer(self):
        """Build the optimizer configured under ``hyper_parameters``.

        Reads ``hyper_parameters.learning_rate`` and
        ``hyper_parameters.optimizer`` from ``self._namespace``, prints the
        learning rate, and delegates to ``_build_optimizer``.
        """
        learning_rate = envs.get_global_env("hyper_parameters.learning_rate",
                                            None, self._namespace)
        optimizer = envs.get_global_env("hyper_parameters.optimizer", None,
                                        self._namespace)
        print(">>>>>>>>>>>.learnig rate: %s" % learning_rate)
        return self._build_optimizer(optimizer, learning_rate)

    @abc.abstractmethod
    def train_net(self):
        """Build the training network; implemented by subclasses."""
        pass

    @abc.abstractmethod
    def infer_net(self):
        """Build the inference network; implemented by subclasses."""
        pass