model.py 8.1 KB
Newer Older
T
tangwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

T
tangwei 已提交
15
import abc
C
Chengmo 已提交
16
import os
T
tangwei 已提交
17 18
import paddle.fluid as fluid

19
from paddlerec.core.utils import envs
T
tangwei 已提交
20

T
tangwei 已提交
21

C
Chengmo 已提交
22
class ModelBase(object):
    """Base class for PaddleRec models.

    Tracks the state shared by every model -- feed variables
    (``_data_var`` / ``_infer_data_var``), cost, metrics, inference
    results, data loaders and the fetch interval -- and provides helpers
    to declare slot feed variables and to build optimizers. Subclasses
    customise behaviour by overriding ``_init_hyper_parameters``,
    ``input_data`` and ``net``.
    """
    # Only Python 2 honours a ``__metaclass__`` class attribute; Python 3
    # ignores it. No method here is abstract, so the class is instantiable
    # either way.
    __metaclass__ = abc.ABCMeta

    def __init__(self, config):
        """Initialise bookkeeping state.

        Args:
            config: runtime configuration; ``_init_slots`` indexes its
                ``"phase"`` and ``"dataset"`` entries.
        """
        self._cost = None
        self._metrics = {}
        self._data_var = []
        self._infer_data_var = []
        self._infer_results = {}
        self._data_loader = None
        self._infer_data_loader = None
        self._fetch_interval = 20
        self._platform = envs.get_platform()
        # NOTE: runs before ``self._env`` is assigned below, so overrides of
        # ``_init_hyper_parameters`` cannot read ``self._env``.
        self._init_hyper_parameters()
        self._env = config
        self._slot_inited = False

    def _init_hyper_parameters(self):
        """Subclass hook called from ``__init__``; does nothing here."""
        pass

    @staticmethod
    def _find_by_name(entries, name, kind):
        """Return the first entry of ``entries`` whose ``"name"`` is ``name``.

        Args:
            entries: iterable of dict-like config entries.
            name: the name to look for.
            kind: label used only in the error message (e.g. "phase").

        Raises:
            KeyError: if no entry matches.
        """
        for entry in entries:
            if entry["name"] == name:
                return entry
        raise KeyError("no %s named %r in the config" % (kind, name))

    @staticmethod
    def _parse_slots(sparse_slots, dense_slots):
        """Parse slot specification strings.

        Args:
            sparse_slots: whitespace-separated sparse slot names.
            dense_slots: whitespace-separated ``name:[d0,d1,...]`` entries.

        Returns:
            ``None`` when both strings are blank; otherwise a tuple
            ``(sparse_names, dense_names, dense_shapes)`` where each dense
            shape is a list of ints.
        """
        sparse_slots = sparse_slots.strip()
        dense_slots = dense_slots.strip()
        if not sparse_slots and not dense_slots:
            return None
        # split() with no separator collapses runs of whitespace, so a stray
        # double space no longer produces an empty slot name.
        sparse_names = sparse_slots.split()
        dense_specs = dense_slots.split()
        dense_names = [spec.split(":")[0] for spec in dense_specs]
        dense_shapes = [
            [int(dim) for dim in spec.split(":")[1].strip("[]").split(",")]
            for spec in dense_specs
        ]
        return sparse_names, dense_names, dense_shapes

    def _create_slot_vars(self, sparse_names, dense_names, dense_shapes):
        """Declare ``fluid`` feed variables for parsed slots.

        Dense slots are declared first as float32 with their configured
        shape, then sparse slots as int64 with ``shape=[1]`` and
        ``lod_level=1``. Resets ``self._dense_data_var`` and
        ``self._sparse_data_var`` to the newly created variables.

        Returns:
            tuple: ``(all_vars, dense_map, sparse_map)``; ``all_vars`` lists
            the dense variables followed by the sparse ones, and each map
            goes from slot name to its variable.
        """
        self._dense_data_var = []
        dense_map = {}
        for slot_name, shape in zip(dense_names, dense_shapes):
            var = fluid.layers.data(
                name=slot_name, shape=shape, dtype="float32")
            self._dense_data_var.append(var)
            dense_map[slot_name] = var
        self._sparse_data_var = []
        sparse_map = {}
        for slot_name in sparse_names:
            var = fluid.layers.data(
                name=slot_name, shape=[1], lod_level=1, dtype="int64")
            self._sparse_data_var.append(var)
            sparse_map[slot_name] = var
        all_vars = self._dense_data_var + self._sparse_data_var
        return all_vars, dense_map, sparse_map

    def _init_slots(self, **kargs):
        """Declare feed variables for the phase named ``kargs["name"]``.

        Finds the phase in ``self._env["phase"]`` and its dataset (via the
        phase's ``dataset_name``) in ``self._env["dataset"]``, reads that
        dataset's ``sparse_slots`` / ``dense_slots`` global env entries and
        appends the resulting feed variables to ``self._data_var``. When the
        dataset's ``type`` is ``"DataLoader"`` a data loader is also built.
        Only the first call on an instance does any work.

        Raises:
            KeyError: if the phase or its dataset is not in the config.
        """
        if self._slot_inited:
            return
        self._slot_inited = True
        phase = self._find_by_name(self._env["phase"], kargs["name"], "phase")
        dataset = self._find_by_name(
            self._env["dataset"], phase["dataset_name"], "dataset")
        prefix = "dataset." + dataset["name"] + "."
        parsed = self._parse_slots(
            envs.get_global_env(prefix + "sparse_slots", ""),
            envs.get_global_env(prefix + "dense_slots", ""))
        if parsed is not None:
            slot_vars, _, _ = self._create_slot_vars(*parsed)
            self._data_var.extend(slot_vars)

        if dataset["type"] == "DataLoader":
            self._init_dataloader()

    @staticmethod
    def _make_data_loader(feed_list):
        """Create a non-iterable ``DataLoader.from_generator`` over feed_list.

        Uses ``capacity=64`` without double buffering, the settings shared
        by every loader this class builds.
        """
        return fluid.io.DataLoader.from_generator(
            feed_list=feed_list,
            capacity=64,
            use_double_buffer=False,
            iterable=False)

    def _init_dataloader(self, is_infer=False):
        """Build ``self._data_loader`` over the train or infer feed vars.

        NOTE(review): the loader is stored in ``self._data_loader`` even when
        ``is_infer`` is True, whereas ``_construct_reader`` uses
        ``self._infer_data_loader`` for inference. Kept as-is because
        callers may read ``_data_loader``; confirm before changing.
        """
        feed_list = self._infer_data_var if is_infer else self._data_var
        self._data_loader = self._make_data_loader(feed_list)

    def get_inputs(self):
        """Return the training feed variables (``self._data_var``)."""
        return self._data_var

    def get_infer_inputs(self):
        """Return the inference feed variables (``self._infer_data_var``)."""
        return self._infer_data_var

    def get_infer_results(self):
        """Return the ``self._infer_results`` dict."""
        return self._infer_results

    def get_avg_cost(self):
        """Return ``self._cost`` (``None`` until a subclass sets it)."""
        return self._cost

    def get_metrics(self):
        """Return the ``self._metrics`` dict."""
        return self._metrics

    def get_fetch_period(self):
        """Return ``self._fetch_interval`` (initialised to 20)."""
        return self._fetch_interval

    def _build_optimizer(self, name, lr, strategy=None):
        """Create a ``fluid`` optimizer by name.

        Also exports ``FLAGS_communicator_is_sgd_optimizer`` ("1" for SGD,
        "0" otherwise) to the process environment.

        Args:
            name: case-insensitive optimizer name: SGD, Adam or Adagrad.
            lr: learning rate passed to the optimizer constructor.
            strategy: accepted for interface compatibility; unused here.

        Returns:
            The constructed ``fluid.optimizer`` instance.

        Raises:
            ValueError: if ``name`` is not a supported optimizer.
        """
        name = name.upper()
        if name not in ("SGD", "ADAM", "ADAGRAD"):
            raise ValueError(
                "configured optimizer can only supported SGD/Adam/Adagrad")

        # Presumably consumed by Paddle's communicator (per the flag name);
        # must be set before the optimizer is constructed.
        os.environ["FLAGS_communicator_is_sgd_optimizer"] = (
            "1" if name == "SGD" else "0")

        if name == "SGD":
            reg = envs.get_global_env("hyper_parameters.reg", 0.0001)
            return fluid.optimizer.SGD(
                lr, regularization=fluid.regularizer.L2DecayRegularizer(reg))
        if name == "ADAM":
            return fluid.optimizer.Adam(lr, lazy_mode=True)
        return fluid.optimizer.Adagrad(lr)

    def optimizer(self):
        """Build the optimizer described by ``hyper_parameters.optimizer``.

        Reads ``class``, ``learning_rate`` and ``strategy`` under that key
        from the global env and delegates to ``_build_optimizer``.
        """
        opt_name = envs.get_global_env("hyper_parameters.optimizer.class")
        opt_lr = envs.get_global_env(
            "hyper_parameters.optimizer.learning_rate")
        opt_strategy = envs.get_global_env(
            "hyper_parameters.optimizer.strategy")

        return self._build_optimizer(opt_name, opt_lr, opt_strategy)

    def input_data(self, is_infer=False, **kwargs):
        """Declare feed variables for a dataset's slots.

        Reads ``dataset.<dataset_name>.sparse_slots`` and
        ``dataset.<dataset_name>.dense_slots`` from the global env and
        declares one feed variable per slot. Also resets and fills
        ``self._dense_data_var_map`` / ``self._sparse_data_var_map`` (slot
        name -> variable) and ``self._dense_data_var`` /
        ``self._sparse_data_var``.

        Args:
            is_infer: unused by the base implementation.
            **kwargs: must contain ``dataset_name``.

        Returns:
            The dense variables followed by the sparse ones, or ``None``
            when the dataset configures no slots.

        Raises:
            TypeError: if ``dataset_name`` is not given.
        """
        dataset_name = kwargs.get("dataset_name")
        if dataset_name is None:
            raise TypeError(
                "input_data() requires a 'dataset_name' keyword argument")
        prefix = "dataset." + dataset_name + "."
        parsed = self._parse_slots(
            envs.get_global_env(prefix + "sparse_slots", ""),
            envs.get_global_env(prefix + "dense_slots", ""))
        self._sparse_data_var_map = {}
        self._dense_data_var_map = {}
        if parsed is None:
            return None
        slot_vars, dense_map, sparse_map = self._create_slot_vars(*parsed)
        self._dense_data_var_map = dense_map
        self._sparse_data_var_map = sparse_map
        return slot_vars

    def net(self, inputs=None, is_infer=False):
        """Build the model network; subclasses override this.

        ``inputs`` receives the feed variables as passed by ``train_net`` /
        ``infer_net``. The base implementation builds nothing and returns
        ``None``.
        """
        return None

    def _construct_reader(self, is_infer=False):
        """Build the data loader for the ``train_net`` / ``infer_net`` path.

        For inference, always builds ``self._infer_data_loader`` over
        ``self._infer_data_var``. For training, builds ``self._data_loader``
        over ``self._data_var`` only when
        ``envs.get_global_env("dataset_class", None, "train.reader")`` is
        ``"DataLoader"``.
        """
        if is_infer:
            self._infer_data_loader = self._make_data_loader(
                self._infer_data_var)
            return
        dataset_class = envs.get_global_env("dataset_class", None,
                                            "train.reader")
        if dataset_class == "DataLoader":
            self._data_loader = self._make_data_loader(self._data_var)

    def train_net(self):
        """Declare training inputs, build the reader, then call ``net``.

        NOTE(review): the base ``input_data`` requires ``dataset_name``,
        which is not passed here, so this path presumably relies on a
        subclass overriding ``input_data`` -- confirm.
        """
        input_data = self.input_data(is_infer=False)
        self._data_var = input_data
        self._construct_reader(is_infer=False)
        self.net(input_data, is_infer=False)

    def infer_net(self):
        """Declare inference inputs, build the reader, then call ``net``.

        NOTE(review): the base ``input_data`` requires ``dataset_name``,
        which is not passed here, so this path presumably relies on a
        subclass overriding ``input_data`` -- confirm.
        """
        input_data = self.input_data(is_infer=True)
        self._infer_data_var = input_data
        self._construct_reader(is_infer=True)
        self.net(input_data, is_infer=True)