Commit 58ab9594 authored by overlordmax

add MMOE model

Parent 6b882d42
# MMOE
The directory structure of this example is outlined below:
```text
├── README.md            # documentation
├── train_mmoe.py        # MMoE model script
├── utils.py             # common utility functions
├── args.py              # argument-parsing script
├── create_data.sh       # script that generates the training data
├── data_preparation.py  # data preprocessing script invoked by create_data.sh
├── train_path           # raw training data file
├── test_path            # raw test data file
├── train_data_path      # processed training data directory
└── test_data_path       # processed test data directory
```
## Introduction
A multi-task model learns the relationships and differences among tasks, which can improve the learning efficiency and quality of every task. Multi-task learning widely adopts a shared-bottom structure in which different tasks share the bottom hidden layers. The paper [MMOE](https://dl.acm.org/doi/10.1145/3219819.3220007) proposes a multi-task learning structure called Multi-gate Mixture-of-Experts (MMoE). MMoE explicitly models task relationships and learns task-specific functions on top of a shared representation, avoiding a significant increase in the number of parameters.
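For intuition, the sketch below shows the MMoE combination in NumPy. The shapes roughly mirror this example (499 input features, 8 experts of width 4, 2 gates), but the weights are random placeholders rather than trained parameters.
```python
# Minimal NumPy sketch of the MMoE combination (illustrative only, random weights).
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

batch, feat_dim, expert_num, expert_size, gate_num = 4, 499, 8, 4, 2
x = np.random.rand(batch, feat_dim).astype("float32")

# f_i(x): each expert is a one-layer ReLU network.
W_e = 0.01 * np.random.randn(expert_num, feat_dim, expert_size)
experts = np.maximum(np.einsum("bf,efh->beh", x, W_e), 0.0)  # [batch, expert_num, expert_size]

# g^k(x): one softmax gate per task gives a weight distribution over the experts.
W_g = 0.01 * np.random.randn(gate_num, feat_dim, expert_num)
gates = softmax(np.einsum("bf,gfe->bge", x, W_g))            # [batch, gate_num, expert_num]

# f^k(x) = sum_i g^k(x)_i * f_i(x): each task tower receives its own mixture of experts.
task_inputs = np.einsum("bge,beh->bgh", gates, experts)      # [batch, gate_num, expert_size]
print(task_inputs.shape)                                     # (4, 2, 4)
```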
## Data download and preprocessing
Dataset: https://archive.ics.uci.edu/ml/datasets/Census-Income+(KDD)
After extracting the data, fill in the file paths in the create_data.sh script and run it:
```shell
mkdir data/data24913/train_data          # directory for processed training data
mkdir data/data24913/test_data           # directory for processed test data
mkdir data/data24913/validation_data     # directory for processed validation data
train_path="data/data24913/census-income.data"           # raw training data path
test_path="data/data24913/census-income.test"            # raw test data path
train_data_path="data/data24913/train_data/"             # processed training data path
test_data_path="data/data24913/test_data/"               # processed test data path
validation_data_path="data/data24913/validation_data/"   # processed validation data path
python data_preparation.py --train_path ${train_path} \
                           --test_path ${test_path} \
                           --train_data_path ${train_data_path} \
                           --test_data_path ${test_data_path} \
                           --validation_data_path ${validation_data_path}
```
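As a quick sanity check (assuming create_data.sh has been run with the default paths above), the processed CSVs place the two task labels in the first two columns, which is the layout the reader in utils.py expects:
```python
# Inspect the processed training data; the path assumes the default layout above.
import pandas as pd

df = pd.read_csv("data/data24913/train_data/train_data.csv")
print(df.columns[:2].tolist())  # expected: ['marital_stat', 'income_50k']
print(df.shape)                 # (rows, 2 label columns + feature columns)
```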
## Single-machine training
GPU environment:
```shell
python train_mmoe.py --use_gpu True \
                     --train_data_path data/data24913/train_data/ \
                     --test_data_path data/data24913/test_data/ \
                     --batch_size 32 \
                     --expert_num 8 \
                     --gate_num 2 \
                     --epochs 400
```
CPU environment:
```shell
python train_mmoe.py --use_gpu False \
                     --train_data_path data/data24913/train_data/ \
                     --test_data_path data/data24913/test_data/ \
                     --batch_size 32 \
                     --expert_num 8 \
                     --gate_num 2 \
                     --epochs 400
```
## Prediction
Training and evaluation alternate in this model: every epoch of train_mmoe.py runs a pass over the test set, so running train_mmoe.py directly produces the prediction results.
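A condensed view of how train_mmoe.py does this (excerpted and slightly simplified from the script further below): the training program is cloned with `for_test=True`, and the test DataLoader is fed through the cloned program every epoch.
```python
# Excerpt (simplified) from train_mmoe.py: evaluation on the test set each epoch.
test_program = fluid.default_main_program().clone(for_test=True)
for test_data in test_loader():
    test_auc_income_val, test_auc_marital_val = exe.run(
        program=test_program,
        feed=test_data,
        fetch_list=[auc_income, auc_marital],
        return_numpy=True)
```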
## Results
Results with epochs set to 100:
```shell
epoch_id:[0],epoch_time:[136.99230 s],loss:[0.48952],train_auc_income:[0.52317],train_auc_marital:[0.78102],test_auc_income:[0.52329],test_auc_marital:[0.84055]
epoch_id:[1],epoch_time:[137.79457 s],loss:[0.48089],train_auc_income:[0.52466],train_auc_marital:[0.92589],test_auc_income:[0.52842],test_auc_marital:[0.93463]
epoch_id:[2],epoch_time:[137.22369 s],loss:[0.43654],train_auc_income:[0.63070],train_auc_marital:[0.95467],test_auc_income:[0.65807],test_auc_marital:[0.95781]
epoch_id:[3],epoch_time:[133.58558 s],loss:[0.44318],train_auc_income:[0.73284],train_auc_marital:[0.96599],test_auc_income:[0.74561],test_auc_marital:[0.96750]
epoch_id:[4],epoch_time:[128.61714 s],loss:[0.41398],train_auc_income:[0.78572],train_auc_marital:[0.97190],test_auc_income:[0.79312],test_auc_marital:[0.97280]
epoch_id:[5],epoch_time:[126.85907 s],loss:[0.44676],train_auc_income:[0.81760],train_auc_marital:[0.97549],test_auc_income:[0.82190],test_auc_marital:[0.97609]
epoch_id:[6],epoch_time:[131.20426 s],loss:[0.40833],train_auc_income:[0.83818],train_auc_marital:[0.97796],test_auc_income:[0.84132],test_auc_marital:[0.97838]
epoch_id:[7],epoch_time:[130.86647 s],loss:[0.39193],train_auc_income:[0.85259],train_auc_marital:[0.97974],test_auc_income:[0.85512],test_auc_marital:[0.98006]
epoch_id:[8],epoch_time:[137.07437 s],loss:[0.43083],train_auc_income:[0.86343],train_auc_marital:[0.98106],test_auc_income:[0.86520],test_auc_marital:[0.98126]
epoch_id:[9],epoch_time:[138.65452 s],loss:[0.38813],train_auc_income:[0.87173],train_auc_marital:[0.98205],test_auc_income:[0.87317],test_auc_marital:[0.98224]
epoch_id:[10],epoch_time:[135.61756 s],loss:[0.39048],train_auc_income:[0.87839],train_auc_marital:[0.98295],test_auc_income:[0.87954],test_auc_marital:[0.98309]
...
...
epoch_id:[95],epoch_time:[134.57041 s],loss:[0.31102],train_auc_income:[0.93345],train_auc_marital:[0.99191],test_auc_income:[0.93348],test_auc_marital:[0.99192]
epoch_id:[96],epoch_time:[134.19668 s],loss:[0.31128],train_auc_income:[0.93354],train_auc_marital:[0.99193],test_auc_income:[0.93357],test_auc_marital:[0.99193]
epoch_id:[97],epoch_time:[126.89334 s],loss:[0.31202],train_auc_income:[0.93361],train_auc_marital:[0.99195],test_auc_income:[0.93363],test_auc_marital:[0.99195]
epoch_id:[98],epoch_time:[136.01872 s],loss:[0.29857],train_auc_income:[0.93370],train_auc_marital:[0.99197],test_auc_income:[0.93372],test_auc_marital:[0.99197]
epoch_id:[99],epoch_time:[133.60402 s],loss:[0.31113],train_auc_income:[0.93379],train_auc_marital:[0.99199],test_auc_income:[0.93382],test_auc_marital:[0.99199]
```
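If you want to plot these curves, the printed log lines can be parsed with a small helper; this is only a convenience sketch (the file name train.log is a placeholder for wherever you redirect stdout) and is not part of the released code.
```python
# Hypothetical helper: parse the printed metrics into a DataFrame for plotting.
import re
import pandas as pd

pattern = re.compile(
    r"epoch_id:\[(\d+)\].*?loss:\[([\d.]+)\].*?"
    r"test_auc_income:\[([\d.]+)\],test_auc_marital:\[([\d.]+)\]")

rows = []
with open("train.log") as f:  # assumes stdout was redirected to train.log
    for line in f:
        m = pattern.search(line)
        if m:
            rows.append([int(m.group(1))] + [float(g) for g in m.groups()[1:]])

df = pd.DataFrame(
    rows, columns=["epoch", "loss", "test_auc_income", "test_auc_marital"])
print(df.tail())
```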
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import distutils.util
def parse_args():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--expert_num", type=int, default=8, help="expert_num")
    parser.add_argument("--gate_num", type=int, default=2, help="gate_num")
    parser.add_argument("--epochs", type=int, default=400, help="epochs")
    parser.add_argument("--batch_size", type=int, default=32, help="batch_size")
    # strtobool correctly parses "True"/"False" strings from the command line
    parser.add_argument(
        '--use_gpu',
        type=distutils.util.strtobool,
        default=False,
        help='whether to use gpu')
    parser.add_argument(
        '--train_data_path',
        type=str,
        default='./data/data24913/train_data/',
        help="train_data_path")
    parser.add_argument(
        '--test_data_path',
        type=str,
        default='./data/data24913/test_data/',
        help="test_data_path")
    args = parser.parse_args()
    return args


def data_preparation_args():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--train_path", type=str, default='', help="train_path")
    parser.add_argument("--test_path", type=str, default='', help="test_path")
    parser.add_argument(
        '--train_data_path', type=str, default='', help="train_data_path")
    parser.add_argument(
        '--test_data_path', type=str, default='', help="test_data_path")
    parser.add_argument(
        '--validation_data_path',
        type=str,
        default='',
        help="validation_data_path")
    args = parser.parse_args()
    return args
mkdir data/data24913/train_data
mkdir data/data24913/test_data
mkdir data/data24913/validation_data
train_path="data/data24913/census-income.data"
test_path="data/data24913/census-income.test"
train_data_path="data/data24913/train_data/"
test_data_path="data/data24913/test_data/"
validation_data_path="data/data24913/validation_data/"
python data_preparation.py --train_path ${train_path} \
                           --test_path ${test_path} \
                           --train_data_path ${train_data_path} \
                           --test_data_path ${test_data_path} \
                           --validation_data_path ${validation_data_path}
import pandas as pd
import numpy as np
import paddle.fluid as fluid
from sklearn.preprocessing import MinMaxScaler
from args import *
def fun1(x):
    # income_50k label: 1 if income is over 50K, else 0
    if x == ' 50000+.':
        return 1
    else:
        return 0


def fun2(x):
    # marital_stat label: 1 if never married, else 0
    if x == ' Never married':
        return 1
    else:
        return 0
def data_preparation(train_path, test_path, train_data_path, test_data_path,
                     validation_data_path):
    # The column names are from
    # https://www2.1010data.com/documentationcenter/prod/Tutorials/MachineLearningExamples/CensusIncomeDataSet.html
    column_names = [
        'age', 'class_worker', 'det_ind_code', 'det_occ_code', 'education',
        'wage_per_hour', 'hs_college', 'marital_stat', 'major_ind_code',
        'major_occ_code', 'race', 'hisp_origin', 'sex', 'union_member',
        'unemp_reason', 'full_or_part_emp', 'capital_gains', 'capital_losses',
        'stock_dividends', 'tax_filer_stat', 'region_prev_res',
        'state_prev_res', 'det_hh_fam_stat', 'det_hh_summ', 'instance_weight',
        'mig_chg_msa', 'mig_chg_reg', 'mig_move_reg', 'mig_same',
        'mig_prev_sunbelt', 'num_emp', 'fam_under_18', 'country_father',
        'country_mother', 'country_self', 'citizenship', 'own_or_self',
        'vet_question', 'vet_benefits', 'weeks_worked', 'year', 'income_50k'
    ]

    # Load the dataset in Pandas
    train_df = pd.read_csv(
        train_path,
        delimiter=',',
        header=None,
        index_col=None,
        names=column_names)
    other_df = pd.read_csv(
        test_path,
        delimiter=',',
        header=None,
        index_col=None,
        names=column_names)

    # First group of tasks according to the paper
    label_columns = ['income_50k', 'marital_stat']

    # One-hot encoding categorical columns
    categorical_columns = [
        'class_worker', 'det_ind_code', 'det_occ_code', 'education',
        'hs_college', 'major_ind_code', 'major_occ_code', 'race', 'hisp_origin',
        'sex', 'union_member', 'unemp_reason', 'full_or_part_emp',
        'tax_filer_stat', 'region_prev_res', 'state_prev_res',
        'det_hh_fam_stat', 'det_hh_summ', 'mig_chg_msa', 'mig_chg_reg',
        'mig_move_reg', 'mig_same', 'mig_prev_sunbelt', 'fam_under_18',
        'country_father', 'country_mother', 'country_self', 'citizenship',
        'vet_question'
    ]
    train_raw_labels = train_df[label_columns]
    other_raw_labels = other_df[label_columns]
    transformed_train = pd.get_dummies(train_df, columns=categorical_columns)
    transformed_other = pd.get_dummies(other_df, columns=categorical_columns)

    # Filling the missing column in the other set
    transformed_other[
        'det_hh_fam_stat_ Grandchild <18 ever marr not in subfamily'] = 0

    # Binarize the two task labels
    transformed_train['income_50k'] = transformed_train['income_50k'].apply(
        lambda x: fun1(x))
    transformed_train['marital_stat'] = transformed_train['marital_stat'].apply(
        lambda x: fun2(x))
    transformed_other['income_50k'] = transformed_other['income_50k'].apply(
        lambda x: fun1(x))
    transformed_other['marital_stat'] = transformed_other['marital_stat'].apply(
        lambda x: fun2(x))

    # Split the other dataset into 1:1 validation to test according to the paper
    validation_indices = transformed_other.sample(
        frac=0.5, replace=False, random_state=1).index
    test_indices = list(set(transformed_other.index) - set(validation_indices))
    validation_data = transformed_other.iloc[validation_indices]
    test_data = transformed_other.iloc[test_indices]

    # Move the label columns to the front: column 0 is marital_stat, column 1 is income_50k
    cols = transformed_train.columns.tolist()
    cols.insert(0, cols.pop(cols.index('income_50k')))
    cols.insert(0, cols.pop(cols.index('marital_stat')))
    transformed_train = transformed_train[cols]
    test_data = test_data[cols]
    validation_data = validation_data[cols]
    print(transformed_train.shape, transformed_other.shape,
          validation_data.shape, test_data.shape)

    transformed_train.to_csv(train_data_path + 'train_data.csv', index=False)
    test_data.to_csv(test_data_path + 'test_data.csv', index=False)
    validation_data.to_csv(
        validation_data_path + 'validation_data.csv', index=False)


args = data_preparation_args()
data_preparation(args.train_path, args.test_path, args.train_data_path,
                 args.test_data_path, args.validation_data_path)
import paddle.fluid as fluid
import pandas as pd
import numpy as np
import paddle
import time
import utils
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import MinMaxScaler
from args import *
import warnings
warnings.filterwarnings("ignore")
# Show all columns when printing DataFrames
pd.set_option('display.max_columns', None)


def set_zero(var_name,
             scope=fluid.global_scope(),
             place=fluid.CPUPlace(),
             param_type="int64"):
    """
    Set tensor of a Variable to zero.
    Args:
        var_name(str): name of Variable
        scope(Scope): Scope object, default is fluid.global_scope()
        place(Place): Place object, default is fluid.CPUPlace()
        param_type(str): param data type, default is int64
    """
    param = scope.var(var_name).get_tensor()
    param_array = np.zeros(param._get_dims()).astype(param_type)
    param.set(param_array, place)
def MMOE(expert_num=8, gate_num=2):
    # Width of each expert's hidden layer
    expert_size = 4
    a_data = fluid.data(name="a", shape=[-1, 499], dtype="float32")
    label_income = fluid.data(
        name="label_income", shape=[-1, 2], dtype="float32", lod_level=0)
    label_marital = fluid.data(
        name="label_marital", shape=[-1, 2], dtype="float32", lod_level=0)

    # f_{i}(x) = activation(W_{i} * x + b), where activation is ReLU according to the paper
    expert_outputs = []
    for i in range(0, expert_num):
        expert_output = fluid.layers.fc(
            input=a_data,
            size=expert_size,
            act='relu',
            bias_attr=fluid.ParamAttr(learning_rate=1.0),
            name='expert_' + str(i))
        expert_outputs.append(expert_output)
    expert_concat = fluid.layers.concat(expert_outputs, axis=1)
    # Keep the expert dimension explicit so the gate weights are applied per expert
    expert_concat = fluid.layers.reshape(expert_concat,
                                         [-1, expert_num, expert_size])

    # g^{k}(x) = activation(W_{gk} * x + b), where activation is softmax according to the paper
    gate_outputs = []
    for i in range(0, gate_num):
        cur_gate = fluid.layers.fc(input=a_data,
                                   size=expert_num,
                                   act='softmax',
                                   bias_attr=fluid.ParamAttr(learning_rate=1.0),
                                   name='gate_' + str(i))
        gate_outputs.append(cur_gate)

    # f^{k}(x) = sum_{i=1}^{n}(g^{k}(x)_{i} * f_{i}(x))
    final_outputs = []
    for gate_output in gate_outputs:
        expanded_gate_output = fluid.layers.reshape(gate_output,
                                                    [-1, expert_num, 1])
        weighted_expert_output = expert_concat * fluid.layers.expand(
            expanded_gate_output, expand_times=[1, 1, expert_size])
        final_outputs.append(
            fluid.layers.reduce_sum(
                weighted_expert_output, dim=1))

    # Build tower layer from MMoE layer
    output_layers = []
    for index, task_layer in enumerate(final_outputs):
        tower_layer = fluid.layers.fc(input=final_outputs[index],
                                      size=8,
                                      act='relu',
                                      name='task_layer_' + str(index))
        output_layer = fluid.layers.fc(input=tower_layer,
                                       size=2,
                                       act='softmax',
                                       name='output_layer_' + str(index))
        output_layers.append(output_layer)

    cost_income = paddle.fluid.layers.cross_entropy(
        input=output_layers[0], label=label_income, soft_label=True)
    cost_marital = paddle.fluid.layers.cross_entropy(
        input=output_layers[1], label=label_marital, soft_label=True)

    label_income_1 = fluid.layers.slice(
        label_income, axes=[1], starts=[1], ends=[2])
    label_marital_1 = fluid.layers.slice(
        label_marital, axes=[1], starts=[1], ends=[2])
    auc_income, batch_auc_1, auc_states_1 = fluid.layers.auc(
        input=output_layers[0],
        label=fluid.layers.cast(
            x=label_income_1, dtype='int64'))
    auc_marital, batch_auc_2, auc_states_2 = fluid.layers.auc(
        input=output_layers[1],
        label=fluid.layers.cast(
            x=label_marital_1, dtype='int64'))

    avg_cost_income = fluid.layers.mean(x=cost_income)
    avg_cost_marital = fluid.layers.mean(x=cost_marital)
    cost = avg_cost_income + avg_cost_marital

    return [
        a_data, label_income, label_marital
    ], cost, output_layers[0], output_layers[
        1], label_income, label_marital, auc_income, auc_marital, auc_states_1, auc_states_2
args = parse_args()
train_path = args.train_data_path
test_path = args.test_data_path
batch_size = args.batch_size
expert_num = args.expert_num
epochs = args.epochs
gate_num = args.gate_num
print("batch_size:[%d],expert_num:[%d],gate_num[%d]" %
(batch_size, expert_num, gate_num))
train_reader = utils.prepare_reader(train_path, batch_size)
test_reader = utils.prepare_reader(test_path, batch_size)
#for data in train_reader():
# print(data[0][1])
data_list, loss, out_1, out_2, label_1, label_2, auc_income, auc_marital, auc_states_1, auc_states_2 = MMOE(
expert_num, gate_num)
Adam = fluid.optimizer.AdamOptimizer()
Adam.minimize(loss)
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
test_program = fluid.default_main_program().clone(for_test=True)
loader = fluid.io.DataLoader.from_generator(
feed_list=data_list, capacity=batch_size, iterable=True)
loader.set_sample_list_generator(train_reader, places=place)
test_loader = fluid.io.DataLoader.from_generator(
feed_list=data_list, capacity=batch_size, iterable=True)
test_loader.set_sample_list_generator(test_reader, places=place)
mean_auc_income = []
mean_auc_marital = []
inference_scope = fluid.Scope()
for epoch in range(epochs):
    for var in auc_states_1:  # reset auc states
        set_zero(var.name, place=place)
    for var in auc_states_2:  # reset auc states
        set_zero(var.name, place=place)
    begin = time.time()
    auc_1_p = 0.0
    auc_2_p = 0.0
    loss_data = 0.0
    for batch_id, train_data in enumerate(loader()):
        loss_data, out_income, out_marital, label_income, label_marital, auc_1_p, auc_2_p = exe.run(
            feed=train_data,
            fetch_list=[
                loss.name, out_1, out_2, label_1, label_2, auc_income,
                auc_marital
            ],
            return_numpy=True)

    for var in auc_states_1:  # reset auc states
        set_zero(var.name, place=place)
    for var in auc_states_2:  # reset auc states
        set_zero(var.name, place=place)
    test_auc_1_p = 0.0
    test_auc_2_p = 0.0
    for batch_id, test_data in enumerate(test_loader()):
        test_out_income, test_out_marital, test_label_income, test_label_marital, test_auc_1_p, test_auc_2_p = exe.run(
            program=test_program,
            feed=test_data,
            fetch_list=[
                out_1, out_2, label_1, label_2, auc_income, auc_marital
            ],
            return_numpy=True)
    mean_auc_income.append(test_auc_1_p)
    mean_auc_marital.append(test_auc_2_p)

    end = time.time()
    print(
        "epoch_id:[%d],epoch_time:[%.5f s],loss:[%.5f],train_auc_income:[%.5f],train_auc_marital:[%.5f],test_auc_income:[%.5f],test_auc_marital:[%.5f]"
        % (epoch, end - begin, loss_data, auc_1_p, auc_2_p, test_auc_1_p,
           test_auc_2_p))

print("mean_auc_income:[%.5f],mean_auc_marital:[%.5f]" %
      (np.mean(mean_auc_income), np.mean(mean_auc_marital)))
import random
import pandas as pd
import numpy as np
import os
import paddle.fluid as fluid
import io
from itertools import islice
from sklearn.preprocessing import MinMaxScaler
import warnings
## Read samples line by line from every file in a directory
def reader_creator(file_dir):
    def reader():
        files = os.listdir(file_dir)
        for fi in files:
            with io.open(
                    os.path.join(file_dir, fi), "r", encoding='utf-8') as f:
                for l in islice(f, 1, None):  # skip the header line
                    l = l.strip().split(',')
                    l = list(map(float, l))
                    label_income = []
                    label_marital = []
                    data = l[2:]
                    # column 0 is marital_stat, column 1 is income_50k;
                    # both are converted to one-hot labels
                    if int(l[1]) == 0:
                        label_income = [1, 0]
                    elif int(l[1]) == 1:
                        label_income = [0, 1]
                    if int(l[0]) == 0:
                        label_marital = [1, 0]
                    elif int(l[0]) == 1:
                        label_marital = [0, 1]
                    label_income = np.array(label_income)
                    label_marital = np.array(label_marital)
                    yield data, label_income, label_marital

    return reader


## Group the sample stream into batches of batch_size (the last incomplete batch is dropped)
def batch_reader(reader, batch_size):
    def batch_reader():
        r = reader()
        b = []
        for instance in r:
            b.append(instance)
            if len(b) == batch_size:
                yield b
                b = []

    return batch_reader


## Build a batched reader over the CSV files under data_path
def prepare_reader(data_path, batch_size):
    data_set = reader_creator(data_path)
    return batch_reader(data_set, batch_size)