# mmoe_train.py
# Train and evaluate a two-task MMoE (multi-gate mixture-of-experts) model with paddle.fluid.
import paddle.fluid as fluid
O
overlordmax 已提交
2
import pandas as pd
Z
zhang wenhui 已提交
3
import numpy as np
O
overlordmax 已提交
4
import paddle
Z
zhang wenhui 已提交
5
import time
O
overlordmax 已提交
6 7
import datetime
import os
O
overlordmax 已提交
8
import utils
Z
zhang wenhui 已提交
9
from args import *
O
overlordmax 已提交
10 11 12 13 14
import logging

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
O
overlordmax 已提交
15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35

def set_zero(var_name, scope=None, place=None, param_type="int64"):
    """
    Set tensor of a Variable to zero.

    Args:
        var_name(str): name of Variable
        scope(Scope): Scope object; None (the default) means
            fluid.global_scope(), looked up when the function is called
        place(Place): Place object; None (the default) means fluid.CPUPlace()
        param_type(str): param data type, default is int64
    """
    # Resolve the defaults at call time. Default-argument expressions are
    # evaluated only once, when the `def` statement runs, which would pin the
    # scope to whatever was global at import time.
    if scope is None:
        scope = fluid.global_scope()
    if place is None:
        place = fluid.CPUPlace()

    param = scope.var(var_name).get_tensor()
    # Build a zero array with the tensor's current dims and overwrite it.
    param_array = np.zeros(param._get_dims(), dtype=param_type)
    param.set(param_array, place)
    
def MMOE(feature_size=499, expert_num=8, gate_num=2, expert_size=16, tower_size=8):
    """
    Build the MMoE network, its losses and its AUC metrics in the default program.

    The network has `expert_num` shared ReLU experts, `gate_num` softmax gates
    that each mix the experts, and one tower plus a 2-way softmax output per
    gate. Only the first two outputs are used: index 0 is trained against
    `label_income` and index 1 against `label_marital`. Any further gates are
    built but do not contribute to the loss.

    Args:
        feature_size(int): width of the input feature vector "a"
        expert_num(int): number of expert layers
        gate_num(int): number of gates / task towers, must be at least 2
        expert_size(int): output width of each expert
        tower_size(int): hidden width of each task tower

    Returns:
        tuple: (feed_list, cost, out_income, out_marital, label_income,
        label_marital, auc_income, auc_marital, auc_states_income,
        auc_states_marital), where feed_list is
        [a_data, label_income, label_marital], cost is the sum of the two mean
        cross-entropy losses, and out_* are the unclipped softmax outputs.

    Raises:
        ValueError: if gate_num is smaller than 2.
    """
    # Outputs 0 and 1 are indexed unconditionally below, so fail early with a
    # clear message instead of an IndexError after part of the graph is built.
    if gate_num < 2:
        raise ValueError(
            "gate_num must be at least 2 (income and marital tasks), got {}".format(gate_num))

    a_data = fluid.data(name="a", shape=[-1, feature_size], dtype="float32")
    label_income = fluid.data(name="label_income", shape=[-1, 2], dtype="float32", lod_level=0)
    label_marital = fluid.data(name="label_marital", shape=[-1, 2], dtype="float32", lod_level=0)

    # f_{i}(x) = activation(W_{i} * x + b), where activation is ReLU according to the paper
    expert_outputs = []
    for i in range(expert_num):
        expert_output = fluid.layers.fc(input=a_data,
                                        size=expert_size,
                                        act='relu',
                                        bias_attr=fluid.ParamAttr(learning_rate=1.0),
                                        name='expert_' + str(i))
        expert_outputs.append(expert_output)
    # Stack the experts as [batch, expert_num, expert_size] so a gate can weight them.
    expert_concat = fluid.layers.concat(expert_outputs, axis=1)
    expert_concat = fluid.layers.reshape(expert_concat, [-1, expert_num, expert_size])

    # g^{k}(x) = activation(W_{gk} * x + b), where activation is softmax according to the paper
    output_layers = []
    for i in range(gate_num):
        cur_gate = fluid.layers.fc(input=a_data,
                                   size=expert_num,
                                   act='softmax',
                                   bias_attr=fluid.ParamAttr(learning_rate=1.0),
                                   name='gate_' + str(i))
        # f^{k}(x) = sum_{i=1}^{n}(g^{k}(x)_{i} * f_{i}(x))
        # axis=0 aligns the [batch, expert_num] gate with the leading dims of the experts.
        cur_gate_expert = fluid.layers.elementwise_mul(expert_concat, cur_gate, axis=0)
        cur_gate_expert = fluid.layers.reduce_sum(cur_gate_expert, dim=1)
        # Build tower layer
        cur_tower = fluid.layers.fc(input=cur_gate_expert,
                                    size=tower_size,
                                    act='relu',
                                    name='task_layer_' + str(i))
        out = fluid.layers.fc(input=cur_tower,
                              size=2,
                              act='softmax',
                              name='out_' + str(i))
        output_layers.append(out)

    # Clip the probabilities away from exactly 0 and 1 before they feed the loss and AUC.
    pred_income = fluid.layers.clip(output_layers[0], min=1e-15, max=1.0 - 1e-15)
    pred_marital = fluid.layers.clip(output_layers[1], min=1e-15, max=1.0 - 1e-15)

    cost_income = fluid.layers.cross_entropy(input=pred_income, label=label_income, soft_label=True)
    cost_marital = fluid.layers.cross_entropy(input=pred_marital, label=label_marital, soft_label=True)

    # Column 1 of each two-column label is used as the integer AUC label.
    label_income_1 = fluid.layers.slice(label_income, axes=[1], starts=[1], ends=[2])
    label_marital_1 = fluid.layers.slice(label_marital, axes=[1], starts=[1], ends=[2])

    # The per-batch AUC outputs are not used by this script.
    auc_income, _, auc_states_1 = fluid.layers.auc(
        input=pred_income, label=fluid.layers.cast(x=label_income_1, dtype='int64'))
    auc_marital, _, auc_states_2 = fluid.layers.auc(
        input=pred_marital, label=fluid.layers.cast(x=label_marital_1, dtype='int64'))

    avg_cost_income = fluid.layers.mean(x=cost_income)
    avg_cost_marital = fluid.layers.mean(x=cost_marital)

    cost = avg_cost_income + avg_cost_marital

    return ([a_data, label_income, label_marital], cost,
            output_layers[0], output_layers[1],
            label_income, label_marital,
            auc_income, auc_marital,
            auc_states_1, auc_states_2)
Z
zhang wenhui 已提交
89 90 91



O
overlordmax 已提交
92 93 94 95 96 97 98 99 100 101
# Hyper-parameters and data paths are read from the parsed command-line arguments.
args = parse_args()
train_path = args.train_data_path
test_path = args.test_data_path
batch_size = args.batch_size
feature_size = args.feature_size
expert_size = args.expert_size
tower_size = args.tower_size
expert_num = args.expert_num
epochs = args.epochs
gate_num = args.gate_num

logger.info("batch_size:{} ,feature_size:{} ,expert_num:{} ,gate_num:{} ,expert_size:{} ,tower_size:{} ,epochs:{} ".format(
    batch_size, feature_size, expert_num, gate_num, expert_size, tower_size, epochs))

train_reader = utils.prepare_reader(train_path, batch_size)
test_reader = utils.prepare_reader(test_path, batch_size)

# Build the network in the default main program.
(data_list, loss, out_1, out_2, label_1, label_2,
 auc_income, auc_marital, auc_states_1, auc_states_2) = MMOE(
    feature_size, expert_num, gate_num, expert_size, tower_size)

# Clone the evaluation program BEFORE minimize() appends backward and optimizer
# ops to the main program; this is the order the Paddle documentation
# recommends for Program.clone(for_test=True).
test_program = fluid.default_main_program().clone(for_test=True)

Adam = fluid.optimizer.AdamOptimizer()
Adam.minimize(loss)

place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

# Train and test loaders feed the same three input variables.
loader = fluid.io.DataLoader.from_generator(feed_list=data_list, capacity=batch_size, iterable=True)
loader.set_sample_list_generator(train_reader, places=place)

test_loader = fluid.io.DataLoader.from_generator(feed_list=data_list, capacity=batch_size, iterable=True)
test_loader.set_sample_list_generator(test_reader, places=place)

# Test AUC collected once per epoch, summarised after training.
auc_income_list = []
auc_marital_list = []
O
overlordmax 已提交
127 128 129 130 131
# The main program does not change between epochs, so look it up once.
main_program = fluid.default_main_program()

for epoch in range(epochs):
    # Reset the AUC accumulator states so the train AUC covers this epoch only.
    for var in list(auc_states_1) + list(auc_states_2):
        set_zero(var.name, place=place)

    begin = time.time()
    auc_1_p = 0.0
    auc_2_p = 0.0
    loss_data = 0.0
    for train_data in loader():
        # The values kept after the loop are those fetched on the last batch.
        loss_data, out_income, out_marital, label_income, label_marital, auc_1_p, auc_2_p = exe.run(
            feed=train_data,
            fetch_list=[loss.name, out_1, out_2, label_1, label_2, auc_income, auc_marital],
            return_numpy=True)

    # Reset again so the test AUC is not mixed with the training statistics.
    for var in list(auc_states_1) + list(auc_states_2):
        set_zero(var.name, place=place)

    test_auc_1_p = 0.0
    test_auc_2_p = 0.0
    for test_data in test_loader():
        test_out_income, test_out_marital, test_label_income, test_label_marital, test_auc_1_p, test_auc_2_p = exe.run(
            program=test_program,
            feed=test_data,
            fetch_list=[out_1, out_2, label_1, label_2, auc_income, auc_marital],
            return_numpy=True)

    # Save a checkpoint per epoch under <model_dir>/epoch_<n>/checkpoint.
    model_dir = os.path.join(args.model_dir, 'epoch_' + str(epoch + 1), "checkpoint")
    fluid.io.save(main_program, model_dir)

    auc_income_list.append(test_auc_1_p)
    auc_marital_list.append(test_auc_2_p)
    # epoch_time therefore covers training, evaluation and checkpoint saving.
    end = time.time()

    logger.info("epoch_id:{},epoch_time:{} s,loss:{},train_auc_income:{},train_auc_marital:{},test_auc_income:{},test_auc_marital:{}".format(
        epoch, end - begin, loss_data, auc_1_p, auc_2_p, test_auc_1_p, test_auc_2_p))

# np.max raises ValueError on an empty list, so only summarise when at least
# one epoch has run.
if auc_income_list:
    logger.info("mean_sb_test_auc_income:{},mean_sb_test_auc_marital:{},max_sb_test_auc_income:{},max_sb_test_auc_marital:{}".format(
        np.mean(auc_income_list), np.mean(auc_marital_list), np.max(auc_income_list), np.max(auc_marital_list)))
else:
    logger.warning("no epoch was run (epochs:{}), skipping the test AUC summary".format(epochs))
O
overlordmax 已提交
170 171 172 173 174 175 176 177