# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""wide and deep model"""
import numpy as np
from mindspore import nn, context
from mindspore import Parameter, ParameterTuple
import mindspore.common.dtype as mstype
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.nn import Dropout
from mindspore.nn.optim import Adam, FTRL, LazyAdam
from mindspore.common.initializer import Uniform, initializer
from mindspore.context import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.communication.management import get_group_size

np_type = np.float32
ms_type = mstype.float32


def init_method(method, shape, name, max_val=1.0):
    '''
    Create a Parameter of the given shape, initialized by 'uniform', 'one', 'zero' or 'normal'.
    '''
    if method in ['uniform']:
        params = Parameter(initializer(
            Uniform(max_val), shape, ms_type), name=name)
    elif method == "one":
        params = Parameter(initializer("ones", shape, ms_type), name=name)
    elif method == 'zero':
        params = Parameter(initializer("zeros", shape, ms_type), name=name)
    elif method == "normal":
        params = Parameter(initializer("normal", shape, ms_type), name=name)
    else:
        raise ValueError("Unsupported init method: {}".format(method))
    return params


def init_var_dict(init_args, in_vars):
    '''
    Build a dict of Parameters, one per (name, shape, init method) entry in in_vars.
    '''
    var_map = {}
    _, _max_val = init_args
    for key, shape, method in in_vars:
        if key not in var_map:
            if method in ['random', 'uniform']:
                var_map[key] = Parameter(initializer(
                    Uniform(_max_val), shape, ms_type), name=key)
            elif method == "one":
                var_map[key] = Parameter(initializer(
                    "ones", shape, ms_type), name=key)
            elif method == "zero":
                var_map[key] = Parameter(initializer(
                    "zeros", shape, ms_type), name=key)
            elif method == 'normal':
                var_map[key] = Parameter(initializer(
                    "normal", shape, ms_type), name=key)
    return var_map


class DenseLayer(nn.Cell):
    """
    Dense layer for the deep part of the Wide&Deep model.
    Applies optional dropout, a matmul, a bias add and an optional activation.
    Args:
        input_dim (int): input dimension.
        output_dim (int): output dimension.
        weight_bias_init (tuple): init methods for the weight and the bias.
        act_str (str): activation name, one of 'relu', 'sigmoid' or 'tanh'.
        keep_prob (float): dropout keep probability. Default: 0.5.
        use_activation (bool): whether to apply the activation. Default: True.
        convert_dtype (bool): whether to run the matmul in float16. Default: True.
        drop_out (bool): whether to apply dropout during training. Default: False.
    """

    def __init__(self, input_dim, output_dim, weight_bias_init, act_str,
                 keep_prob=0.5, use_activation=True, convert_dtype=True, drop_out=False):
        super(DenseLayer, self).__init__()
        weight_init, bias_init = weight_bias_init
        self.weight = init_method(
            weight_init, [input_dim, output_dim], name="weight")
        self.bias = init_method(bias_init, [output_dim], name="bias")
        self.act_func = self._init_activation(act_str)
        self.matmul = P.MatMul(transpose_b=False)
        self.bias_add = P.BiasAdd()
        self.cast = P.Cast()
        self.dropout = Dropout(keep_prob=keep_prob)
        self.use_activation = use_activation
        self.convert_dtype = convert_dtype
        self.drop_out = drop_out

    def _init_activation(self, act_str):
        act_str = act_str.lower()
        if act_str == "relu":
            act_func = P.ReLU()
        elif act_str == "sigmoid":
            act_func = P.Sigmoid()
        elif act_str == "tanh":
            act_func = P.Tanh()
        else:
            raise ValueError("Unsupported activation: {}".format(act_str))
        return act_func

    def construct(self, x):
        '''
        Forward pass: optional dropout, matmul (optionally in float16), bias add, optional activation.
        '''
        if self.training and self.drop_out:
            x = self.dropout(x)
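        # Optionally run the matmul in float16 for speed, casting the result back to float32.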
        if self.convert_dtype:
            x = self.cast(x, mstype.float16)
            weight = self.cast(self.weight, mstype.float16)
            bias = self.cast(self.bias, mstype.float16)
            wx = self.matmul(x, weight)
            wx = self.bias_add(wx, bias)
            if self.use_activation:
                wx = self.act_func(wx)
            wx = self.cast(wx, mstype.float32)
        else:
            wx = self.matmul(x, self.weight)
            wx = self.bias_add(wx, self.bias)
            if self.use_activation:
                wx = self.act_func(wx)
        return wx


class WideDeepModel(nn.Cell):
    """
        Wide&Deep model, from the paper "Wide & Deep Learning for Recommender Systems".
        Args:
            config (Class): The default config of Wide&Deep.
    """

    def __init__(self, config):
        super(WideDeepModel, self).__init__()
        self.batch_size = config.batch_size
        host_device_mix = bool(config.host_device_mix)
        parameter_server = bool(config.parameter_server)
        parallel_mode = context.get_auto_parallel_context("parallel_mode")
        is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
        if is_auto_parallel:
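            # In (semi-)auto parallel mode the graph handles the global batch, so scale by the device count.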
            self.batch_size = self.batch_size * get_group_size()
        is_field_slice = config.field_slice
        self.field_size = config.field_size
        self.vocab_size = config.vocab_size
        self.emb_dim = config.emb_dim
        self.deep_layer_dims_list = config.deep_layer_dim
        self.deep_layer_act = config.deep_layer_act
        self.init_args = config.init_args
        self.weight_init, self.bias_init = config.weight_bias_init
        self.weight_bias_init = config.weight_bias_init
        self.emb_init = config.emb_init
        self.drop_out = config.dropout_flag
        self.keep_prob = config.keep_prob
        self.deep_input_dims = self.field_size * self.emb_dim
        self.layer_dims = self.deep_layer_dims_list + [1]
        self.all_dim_list = [self.deep_input_dims] + self.layer_dims

        init_acts = [('Wide_b', [1], self.emb_init)]
        var_map = init_var_dict(self.init_args, init_acts)
        self.wide_b = var_map["Wide_b"]
        self.dense_layer_1 = DenseLayer(self.all_dim_list[0],
                                        self.all_dim_list[1],
                                        self.weight_bias_init,
                                        self.deep_layer_act,
                                        convert_dtype=True, drop_out=config.dropout_flag)
        self.dense_layer_2 = DenseLayer(self.all_dim_list[1],
                                        self.all_dim_list[2],
                                        self.weight_bias_init,
                                        self.deep_layer_act,
                                        convert_dtype=True, drop_out=config.dropout_flag)
        self.dense_layer_3 = DenseLayer(self.all_dim_list[2],
                                        self.all_dim_list[3],
                                        self.weight_bias_init,
                                        self.deep_layer_act,
                                        convert_dtype=True, drop_out=config.dropout_flag)
        self.dense_layer_4 = DenseLayer(self.all_dim_list[3],
                                        self.all_dim_list[4],
                                        self.weight_bias_init,
                                        self.deep_layer_act,
                                        convert_dtype=True, drop_out=config.dropout_flag)
        self.dense_layer_5 = DenseLayer(self.all_dim_list[4],
                                        self.all_dim_list[5],
                                        self.weight_bias_init,
                                        self.deep_layer_act,
                                        use_activation=False, convert_dtype=True, drop_out=config.dropout_flag)
        self.wide_mul = P.Mul()
        self.deep_mul = P.Mul()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.reshape = P.Reshape()
        self.deep_reshape = P.Reshape()
        self.square = P.Square()
        self.shape = P.Shape()
        self.tile = P.Tile()
        self.concat = P.Concat(axis=1)
        self.cast = P.Cast()
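        # Select how the embedding tables are placed and sliced: host/device mix, field slice,
        # parameter server, or plain on-device lookup.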
        if is_auto_parallel and host_device_mix and not is_field_slice:
            self.dense_layer_1.dropout.dropout_do_mask.set_strategy(((1, get_group_size()),))
            self.dense_layer_1.dropout.dropout.set_strategy(((1, get_group_size()),))
            self.dense_layer_1.matmul.set_strategy(((1, get_group_size()), (get_group_size(), 1)))
            self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim,
                                                           slice_mode=nn.EmbeddingLookUpSplitMode.TABLE_COLUMN_SLICE)
            self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1,
                                                           slice_mode=nn.EmbeddingLookUpSplitMode.TABLE_ROW_SLICE)
            self.deep_mul.set_strategy(((1, 1, get_group_size()), (1, 1, 1)))
            self.deep_reshape.add_prim_attr("skip_redistribution", True)
            self.reduce_sum.add_prim_attr("cross_batch", True)
            self.embedding_table = self.deep_embeddinglookup.embedding_table
        elif is_auto_parallel and host_device_mix and is_field_slice and config.full_batch and config.manual_shape:
            manual_shapes = tuple((s[0] for s in config.manual_shape))
            self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim,
                                                           slice_mode=nn.EmbeddingLookUpSplitMode.FIELD_SLICE,
                                                           manual_shapes=manual_shapes)
            self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1,
                                                           slice_mode=nn.EmbeddingLookUpSplitMode.FIELD_SLICE,
                                                           manual_shapes=manual_shapes)
            self.deep_mul.set_strategy(((1, get_group_size(), 1), (1, get_group_size(), 1)))
            self.wide_mul.set_strategy(((1, get_group_size(), 1), (1, get_group_size(), 1)))
            self.reduce_sum.set_strategy(((1, get_group_size(), 1),))
            self.dense_layer_1.dropout.dropout_do_mask.set_strategy(((1, get_group_size()),))
            self.dense_layer_1.dropout.dropout.set_strategy(((1, get_group_size()),))
            self.dense_layer_1.matmul.set_strategy(((1, get_group_size()), (get_group_size(), 1)))
            self.embedding_table = self.deep_embeddinglookup.embedding_table
        elif parameter_server:
            self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim)
            self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1)
            self.embedding_table = self.deep_embeddinglookup.embedding_table
            self.deep_embeddinglookup.embedding_table.set_param_ps()
            self.wide_embeddinglookup.embedding_table.set_param_ps()
        else:
            self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target='DEVICE')
            self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, target='DEVICE')
            self.embedding_table = self.deep_embeddinglookup.embedding_table

    def construct(self, id_hldr, wt_hldr):
        """
        Args:
            id_hldr: batch ids;
            wt_hldr: batch weights;
        """
        mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1))
        # Wide layer
        wide_id_weight = self.wide_embeddinglookup(id_hldr)
        wx = self.wide_mul(wide_id_weight, mask)
        wide_out = self.reshape(self.reduce_sum(wx, 1) + self.wide_b, (-1, 1))
        # Deep layer
        deep_id_embs = self.deep_embeddinglookup(id_hldr)
        vx = self.deep_mul(deep_id_embs, mask)
        deep_in = self.deep_reshape(vx, (-1, self.field_size * self.emb_dim))
        deep_in = self.dense_layer_1(deep_in)
        deep_in = self.dense_layer_2(deep_in)
        deep_in = self.dense_layer_3(deep_in)
        deep_in = self.dense_layer_4(deep_in)
        deep_out = self.dense_layer_5(deep_in)
        out = wide_out + deep_out
        return out, self.embedding_table


class NetWithLossClass(nn.Cell):
    """
    Provide WideDeep training loss through network.
    Args:
        network (Cell): The training network
        config (Class): WideDeep config
    """

    def __init__(self, network, config):
        super(NetWithLossClass, self).__init__(auto_prefix=False)
        host_device_mix = bool(config.host_device_mix)
        parameter_server = bool(config.parameter_server)
        parallel_mode = context.get_auto_parallel_context("parallel_mode")
        is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
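        # Skip the embedding-table L2 penalty when the table is sliced across devices or held on parameter servers.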
        self.no_l2loss = (is_auto_parallel if (host_device_mix or config.field_slice) else parameter_server)
        self.network = network
        self.l2_coef = config.l2_coef
        self.loss = P.SigmoidCrossEntropyWithLogits()
        self.square = P.Square()
        self.reduceMean_false = P.ReduceMean(keep_dims=False)
        if is_auto_parallel:
            self.reduceMean_false.add_prim_attr("cross_batch", True)
        self.reduceSum_false = P.ReduceSum(keep_dims=False)

    def construct(self, batch_ids, batch_wts, label):
        '''
        Compute the wide loss and the deep loss (the deep loss optionally adds an embedding L2 term).
        '''
        predict, embedding_table = self.network(batch_ids, batch_wts)
        log_loss = self.loss(predict, label)
        wide_loss = self.reduceMean_false(log_loss)
        if self.no_l2loss:
            deep_loss = wide_loss
        else:
            l2_loss_v = self.reduceSum_false(self.square(embedding_table)) / 2
            deep_loss = self.reduceMean_false(log_loss) + self.l2_coef * l2_loss_v

        return wide_loss, deep_loss


class IthOutputCell(nn.Cell):
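    """Wraps a multi-output network and returns only the output at `output_index`."""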
    def __init__(self, network, output_index):
        super(IthOutputCell, self).__init__()
        self.network = network
        self.output_index = output_index

    def construct(self, x1, x2, x3):
        predict = self.network(x1, x2, x3)[self.output_index]
        return predict


class TrainStepWrap(nn.Cell):
    """
    Encapsulation class of WideDeep network training.
    Appends the Adam/LazyAdam and FTRL optimizers to the training network; after that, the
    construct function can be called to create the backward graph.
    Args:
        network (Cell): The training network with the loss function already attached.
        sens (Number): The loss-scaling (sensitivity) value. Default: 1024.0
        host_device_mix (Bool): Whether to run in host and device mix mode. Default: False
        parameter_server (Bool): Whether to run in parameter server mode. Default: False
    """

    def __init__(self, network, sens=1024.0, host_device_mix=False, parameter_server=False):
        super(TrainStepWrap, self).__init__()
        parallel_mode = context.get_auto_parallel_context("parallel_mode")
        is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
        self.network = network
        self.network.set_train()
        self.trainable_params = network.trainable_params()
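        # Split the trainable parameters: the wide (linear) part is optimized with FTRL,
        # the rest (deep layers and embeddings) with Adam/LazyAdam.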
        weights_w = []
        weights_d = []
        for params in self.trainable_params:
            if 'wide' in params.name:
                weights_w.append(params)
            else:
                weights_d.append(params)
        self.weights_w = ParameterTuple(weights_w)
        self.weights_d = ParameterTuple(weights_d)

        if (host_device_mix and is_auto_parallel) or parameter_server:
            self.optimizer_d = LazyAdam(
                self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens)
            self.optimizer_w = FTRL(learning_rate=5e-2, params=self.weights_w,
                                    l1=1e-8, l2=1e-8, initial_accum=1.0, loss_scale=sens)
            self.optimizer_w.sparse_opt.add_prim_attr("primitive_target", "CPU")
            self.optimizer_d.sparse_opt.add_prim_attr("primitive_target", "CPU")
        else:
            self.optimizer_d = Adam(
                self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens)
            self.optimizer_w = FTRL(learning_rate=5e-2, params=self.weights_w,
                                    l1=1e-8, l2=1e-8, initial_accum=1.0, loss_scale=sens)
        self.hyper_map = C.HyperMap()
        self.grad_w = C.GradOperation(get_by_list=True,
                                      sens_param=True)
        self.grad_d = C.GradOperation(get_by_list=True,
                                      sens_param=True)
        self.sens = sens
        self.loss_net_w = IthOutputCell(network, output_index=0)
        self.loss_net_d = IthOutputCell(network, output_index=1)

        self.reducer_flag = False
        self.grad_reducer_w = None
        self.grad_reducer_d = None
        self.reducer_flag = parallel_mode in (ParallelMode.DATA_PARALLEL,
                                              ParallelMode.HYBRID_PARALLEL)
        if self.reducer_flag:
            mean = context.get_auto_parallel_context("gradients_mean")
            degree = context.get_auto_parallel_context("device_num")
            self.grad_reducer_w = DistributedGradReducer(self.optimizer_w.parameters, mean, degree)
            self.grad_reducer_d = DistributedGradReducer(self.optimizer_d.parameters, mean, degree)

    def construct(self, batch_ids, batch_wts, label):
        '''
        One training step: compute both losses and update the wide (FTRL) and deep (Adam) parameters.
        '''
        weights_w = self.weights_w
        weights_d = self.weights_d
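        # The network returns both losses; each is back-propagated with its own loss-scale (sens)
        # tensor and applied by its own optimizer.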
        loss_w, loss_d = self.network(batch_ids, batch_wts, label)
        sens_w = P.Fill()(P.DType()(loss_w), P.Shape()(loss_w), self.sens)
        sens_d = P.Fill()(P.DType()(loss_d), P.Shape()(loss_d), self.sens)
        grads_w = self.grad_w(self.loss_net_w, weights_w)(batch_ids, batch_wts,
                                                          label, sens_w)
        grads_d = self.grad_d(self.loss_net_d, weights_d)(batch_ids, batch_wts,
                                                          label, sens_d)
        if self.reducer_flag:
            grads_w = self.grad_reducer_w(grads_w)
            grads_d = self.grad_reducer_d(grads_d)
        return F.depend(loss_w, self.optimizer_w(grads_w)), F.depend(loss_d,
                                                                     self.optimizer_d(grads_d))


class PredictWithSigmoid(nn.Cell):
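    """Evaluation wrapper: returns the raw logits, their sigmoid probabilities and the labels."""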
    def __init__(self, network):
        super(PredictWithSigmoid, self).__init__()
        self.network = network
        self.sigmoid = P.Sigmoid()

    def construct(self, batch_ids, batch_wts, labels):
        logits, _ = self.network(batch_ids, batch_wts)
        pred_probs = self.sigmoid(logits)
        return logits, pred_probs, labels