# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc

import paddle.fluid as fluid

from paddlerec.core.utils import envs


class Model(object):
    """R
    """
    __metaclass__ = abc.ABCMeta  # Py2-style declaration; ignored on Python 3

    def __init__(self, config):
        """R
        """
        self._cost = None
        self._metrics = {}
        self._data_var = []
        self._infer_data_var = []
        self._infer_results = {}
        self._data_loader = None
        self._infer_data_loader = None
        self._fetch_interval = 20  # batches between metric fetches; see get_fetch_period()
        self._namespace = "train.model"
        self._platform = envs.get_platform()

    def _init_slots(self):
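        """Declare input variables from the reader's slot config.

        Assumed formats, inferred from the parsing below: sparse_slots is a
        space-separated list of slot names (each becomes an int64 lod_level=1
        var); each dense_slots entry looks like "name:[d1,d2]" (each becomes a
        float32 var of that shape), e.g. dense_slots: "dense_input:[13]".
        """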
        sparse_slots = envs.get_global_env(
            "sparse_slots", None, "train.reader")
        dense_slots = envs.get_global_env(
            "dense_slots", None, "train.reader")

        if sparse_slots is not None or dense_slots is not None:
            # Guard against only one of the two slot types being configured.
            sparse_slots = sparse_slots.strip().split(" ") if sparse_slots else []
            dense_slots = dense_slots.strip().split(" ") if dense_slots else []
            dense_slots_shape = [
                [int(j) for j in i.split(":")[1].strip("[]").split(",")]
                for i in dense_slots
            ]
            dense_slots = [i.split(":")[0] for i in dense_slots]
            self._dense_data_var = []
            for i in range(len(dense_slots)):
                var = fluid.layers.data(
                    name=dense_slots[i], shape=dense_slots_shape[i], dtype="float32")
                self._data_var.append(var)
                self._dense_data_var.append(var)
            self._sparse_data_var = []
            for name in sparse_slots:
                var = fluid.layers.data(
                    name=name, shape=[1], lod_level=1, dtype="int64")
                self._data_var.append(var)
                self._sparse_data_var.append(var)

        dataset_class = envs.get_global_env("dataset_class", None, "train.reader")
        if dataset_class == "DataLoader":
            self._init_dataloader()

    def _init_dataloader(self):
        self._data_loader = fluid.io.DataLoader.from_generator(
            feed_list=self._data_var, capacity=64, use_double_buffer=False, iterable=False)
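
    # Usage sketch (illustrative, not part of the original file): with
    # iterable=False a trainer typically binds a generator and drives the
    # loader explicitly, e.g.
    #
    #   model._data_loader.set_sample_generator(reader, batch_size=32)
    #   model._data_loader.start()
    #   ...  # run the executor until fluid.core.EOFException is raised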

    def get_inputs(self):
        return self._data_var

    def get_infer_inputs(self):
        return self._infer_data_var

    def get_infer_results(self):
        return self._infer_results

    def get_avg_cost(self):
        """Return the cost (loss) variable built by train_net()."""
        return self._cost

    def get_metrics(self):
        """R
        """
        return self._metrics

    def get_fetch_period(self):
        return self._fetch_interval

    def _build_optimizer(self, name, lr):
        name = name.upper()
        optimizers = ["SGD", "ADAM", "ADAGRAD"]
        if name not in optimizers:
            raise ValueError(
                "configured optimizer can only be SGD/Adam/Adagrad")

        if name == "SGD":
            reg = envs.get_global_env(
                "hyper_parameters.reg", 0.0001, self._namespace)
            optimizer_i = fluid.optimizer.SGD(
                lr, regularization=fluid.regularizer.L2DecayRegularizer(reg))
        elif name == "ADAM":
            optimizer_i = fluid.optimizer.Adam(lr, lazy_mode=True)
        elif name == "ADAGRAD":
            optimizer_i = fluid.optimizer.Adagrad(lr)
        else:
            raise ValueError(
                "configured optimizer can only be SGD/Adam/Adagrad")

        return optimizer_i

    def optimizer(self):
        learning_rate = envs.get_global_env(
            "hyper_parameters.learning_rate", None, self._namespace)
        optimizer = envs.get_global_env(
            "hyper_parameters.optimizer", None, self._namespace)
        print("learning rate: %s" % learning_rate)
        return self._build_optimizer(optimizer, learning_rate)
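
    # Example of the config consumed above (illustrative snippet of a typical
    # PaddleRec yaml; the keys follow the get_global_env lookups in this
    # class, the values are placeholders):
    #
    #   train:
    #     model:
    #       hyper_parameters:
    #         learning_rate: 0.001
    #         optimizer: adam
    #         reg: 0.0001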

    @abc.abstractmethod
    def train_net(self):
        """Build the training network.

        Subclasses must set self._cost and may fill self._metrics.
        """
        pass

    @abc.abstractmethod
    def infer_net(self):
        """Build the inference network.

        Subclasses must set self._infer_data_var and self._infer_results.
        """
        pass
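

# Illustrative sketch (not part of the original file): a minimal subclass
# showing how the abstract hooks are typically filled in. The input names,
# shapes, and network below are assumptions for the example only.
class _ExampleModel(Model):
    def train_net(self):
        # Declare inputs directly instead of via _init_slots().
        x = fluid.layers.data(name="x", shape=[13], dtype="float32")
        y = fluid.layers.data(name="y", shape=[1], dtype="float32")
        self._data_var = [x, y]
        # A one-layer regressor with mean squared error as the cost.
        y_hat = fluid.layers.fc(input=x, size=1)
        self._cost = fluid.layers.reduce_mean(
            fluid.layers.square_error_cost(input=y_hat, label=y))
        self._metrics["loss"] = self._cost

    def infer_net(self):
        x = fluid.layers.data(name="x", shape=[13], dtype="float32")
        self._infer_data_var = [x]
        self._infer_results["y_hat"] = fluid.layers.fc(input=x, size=1)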