Commit 7ba6d01d
Authored Mar 26, 2020 by overlordmax

add mmoe and share_bottom

Parent: 6b882d42

Showing 16 changed files with 996 additions and 132 deletions (+996 −132)
PaddleRec/multi-task/MMoE/README.md                    +93  −14
PaddleRec/multi-task/MMoE/args.py                      +31  −6
PaddleRec/multi-task/MMoE/create_data.sh               +15  −0
PaddleRec/multi-task/MMoE/data_preparation.py          +108 −0
PaddleRec/multi-task/MMoE/mmoe_train.py                +172 −112
PaddleRec/multi-task/MMoE/train_cpu.sh                 +7   −0
PaddleRec/multi-task/MMoE/train_gpu.sh                 +7   −0
PaddleRec/multi-task/MMoE/utils.py                     +60  −0
PaddleRec/multi-task/Share_bottom/README.md            +103 −0
PaddleRec/multi-task/Share_bottom/args.py              +57  −0
PaddleRec/multi-task/Share_bottom/create_data.sh       +15  −0
PaddleRec/multi-task/Share_bottom/data_preparation.py  +108 −0
PaddleRec/multi-task/Share_bottom/share_bottom.py      +153 −0
PaddleRec/multi-task/Share_bottom/train_cpu.sh         +5   −0
PaddleRec/multi-task/Share_bottom/train_gpu.sh         +5   −0
PaddleRec/multi-task/Share_bottom/utils.py             +57  −0
PaddleRec/multi-task/MMoE/README.md (new version)

# MMOE

Below is a brief directory layout for this example:

```
├── README.md             # documentation
├── mmoe_train.py         # MMoE model script
├── utils.py              # common utility functions
├── args.py               # argument parsing
├── create_data.sh        # script that generates the training data
├── data_preparation.py   # data preprocessing script
├── train_gpu.sh          # GPU training script
├── train_cpu.sh          # CPU training script
```

## Introduction

A multi-task model learns the relationships and the differences between tasks, which can improve the learning efficiency and quality of every task. Multi-task learning frameworks widely adopt the shared-bottom structure, in which the bottom hidden layers are shared across tasks. This structure inherently reduces the risk of overfitting, but its accuracy can suffer when the tasks differ or their data distributions diverge.

The paper [Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts](https://www.kdd.org/kdd2018/accepted-papers/view/modeling-task-relationships-in-multi-task-learning-with-multi-gate-mixture-), published at KDD 2018, proposes the Multi-gate Mixture-of-Experts (MMOE) multi-task learning structure. MMOE models task relationships and learns task-specific functions on top of shared representations, while avoiding a significant increase in the number of parameters.

We define the MMOE network structure in PaddlePaddle and match the paper's results on the open Census-income dataset. This project supports single-machine training on both GPU and CPU.
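For reference, the gating computation described above can be written out explicitly; this restates the paper's formulation (which also appears as comments in mmoe_train.py), with $n$ experts and tasks indexed by $k$:

```latex
f_i(x) = \mathrm{ReLU}(W_i x + b_i), \qquad i = 1, \dots, n   % expert networks
g^k(x) = \mathrm{softmax}(W_{gk}\, x)                          % one gate per task
f^k(x) = \sum_{i=1}^{n} g^k(x)_i \, f_i(x)                     % gate-weighted mixture of experts
y_k    = h^k\!\left(f^k(x)\right)                              % task-specific tower
```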
## Data download and preprocessing

Dataset: [Census-income Data](https://archive.ics.uci.edu/ml/datasets/Census-Income+(KDD))

After extracting the data, fill in the file paths in the create_data.sh script and run it:

```
mkdir data/data24913/train_data                         # create the training-data directory
mkdir data/data24913/test_data                          # create the test-data directory
mkdir data/data24913/validation_data                    # create the validation-data directory
train_path="data/data24913/census-income.data"          # raw training-data path
test_path="data/data24913/census-income.test"           # raw test-data path
train_data_path="data/data24913/train_data/"            # preprocessed training-data path
test_data_path="data/data24913/test_data/"              # preprocessed test-data path
validation_data_path="data/data24913/validation_data/"  # preprocessed validation-data path
python data_preparation.py --train_path ${train_path} \
                           --test_path ${test_path} \
                           --train_data_path ${train_data_path} \
                           --test_data_path ${test_data_path} \
                           --validation_data_path ${validation_data_path}
```
## Environment

PaddlePaddle 1.7.0

python3.7

## Single-machine training

GPU environment

Set the data paths and parameters in the train_gpu.sh script:

```
# Train on GPU: data paths, batch size, expert/gate counts, and epochs.
python train_mmoe.py --use_gpu True \
                     --train_path data/data24913/train_data/ \
                     --test_path data/data24913/test_data/ \
                     --batch_size 32 \
                     --expert_num 8 \
                     --gate_num 2 \
                     --epochs 400
```

Make the script executable and run it:

```
./train_gpu.sh
```

CPU environment

Set the data paths and parameters in the train_cpu.sh script:

```
# Train on CPU with the same data paths and hyperparameters.
python train_mmoe.py --use_gpu False \
                     --train_path data/data24913/train_data/ \
                     --test_path data/data24913/test_data/ \
                     --batch_size 32 \
                     --expert_num 8 \
                     --gate_num 2 \
                     --epochs 400
```

Make the script executable and run it:

```
./train_cpu.sh
```
## Prediction

Training and prediction alternate in this model; running train_mmoe.py also produces the prediction results.

## Results

Training and test results with epoch set to 100:

![1585193459635](C:\Users\overlord\AppData\Roaming\Typora\typora-user-images\1585193459635.png)
PaddleRec/multi-task/MMoE/args.py (hunk @@ -22,14 +22,39 @@, new version shown)

```
def parse_args():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--expert_num", type=int, default=8, help="expert_num")
    parser.add_argument("--gate_num", type=int, default=2, help="gate_num")
    parser.add_argument("--epochs", type=int, default=400, help="epochs")
    parser.add_argument("--batch_size", type=int, default=32, help="batch_size")
    parser.add_argument("--base_lr", type=float, default=0.01, help="learning_rate")
    parser.add_argument('--use_gpu', type=bool, default=False, help='whether using gpu')
    parser.add_argument('--train_data_path', type=str,
                        default='./data/data24913/train_data/', help="train_data_path")
    parser.add_argument('--test_data_path', type=str,
                        default='./data/data24913/test_data/', help="test_data_path")
    args = parser.parse_args()
    return args


def data_preparation_args():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--train_path", type=str, default='', help="train_path")
    parser.add_argument("--test_path", type=str, default='', help="test_path")
    parser.add_argument('--train_data_path', type=str, default='', help="train_data_path")
    parser.add_argument('--test_data_path', type=str, default='', help="test_data_path")
    parser.add_argument('--validation_data_path', type=str, default='',
                        help="validation_data_path")
    args = parser.parse_args()
    return args
```
PaddleRec/multi-task/MMoE/create_data.sh (new file, mode 100644)

```
mkdir data/data24913/train_data
mkdir data/data24913/test_data
mkdir data/data24913/validation_data

train_path="data/data24913/census-income.data"
test_path="data/data24913/census-income.test"
train_data_path="data/data24913/train_data/"
test_data_path="data/data24913/test_data/"
validation_data_path="data/data24913/validation_data/"

python data_preparation.py --train_path ${train_path} \
                           --test_path ${test_path} \
                           --train_data_path ${train_data_path} \
                           --test_data_path ${test_data_path} \
                           --validation_data_path ${validation_data_path}
```
PaddleRec/multi-task/MMoE/data_preparation.py (new file, mode 100644)

```
import pandas as pd
import numpy as np
import paddle.fluid as fluid
from sklearn.preprocessing import MinMaxScaler
from args import *


def fun1(x):
    if x == ' 50000+.':
        return 1
    else:
        return 0


def fun2(x):
    if x == ' Never married':
        return 1
    else:
        return 0


def data_preparation(train_path, test_path, train_data_path, test_data_path,
                     validation_data_path):
    # The column names are from
    # https://www2.1010data.com/documentationcenter/prod/Tutorials/MachineLearningExamples/CensusIncomeDataSet.html
    column_names = [
        'age', 'class_worker', 'det_ind_code', 'det_occ_code', 'education',
        'wage_per_hour', 'hs_college', 'marital_stat', 'major_ind_code',
        'major_occ_code', 'race', 'hisp_origin', 'sex', 'union_member',
        'unemp_reason', 'full_or_part_emp', 'capital_gains', 'capital_losses',
        'stock_dividends', 'tax_filer_stat', 'region_prev_res',
        'state_prev_res', 'det_hh_fam_stat', 'det_hh_summ', 'instance_weight',
        'mig_chg_msa', 'mig_chg_reg', 'mig_move_reg', 'mig_same',
        'mig_prev_sunbelt', 'num_emp', 'fam_under_18', 'country_father',
        'country_mother', 'country_self', 'citizenship', 'own_or_self',
        'vet_question', 'vet_benefits', 'weeks_worked', 'year', 'income_50k'
    ]
    # Load the dataset in Pandas
    train_df = pd.read_csv(
        train_path,
        delimiter=',',
        header=None,
        index_col=None,
        names=column_names)
    other_df = pd.read_csv(
        test_path,
        delimiter=',',
        header=None,
        index_col=None,
        names=column_names)
    # First group of tasks according to the paper
    label_columns = ['income_50k', 'marital_stat']
    # One-hot encoding categorical columns
    categorical_columns = [
        'class_worker', 'det_ind_code', 'det_occ_code', 'education',
        'hs_college', 'major_ind_code', 'major_occ_code', 'race',
        'hisp_origin', 'sex', 'union_member', 'unemp_reason',
        'full_or_part_emp', 'tax_filer_stat', 'region_prev_res',
        'state_prev_res', 'det_hh_fam_stat', 'det_hh_summ', 'mig_chg_msa',
        'mig_chg_reg', 'mig_move_reg', 'mig_same', 'mig_prev_sunbelt',
        'fam_under_18', 'country_father', 'country_mother', 'country_self',
        'citizenship', 'vet_question'
    ]
    train_raw_labels = train_df[label_columns]
    other_raw_labels = other_df[label_columns]
    transformed_train = pd.get_dummies(train_df, columns=categorical_columns)
    transformed_other = pd.get_dummies(other_df, columns=categorical_columns)
    # Filling the missing column in the other set
    transformed_other[
        'det_hh_fam_stat_ Grandchild <18 ever marr not in subfamily'] = 0
    # Binarize the two label columns
    transformed_train['income_50k'] = transformed_train['income_50k'].apply(
        lambda x: fun1(x))
    transformed_train['marital_stat'] = transformed_train[
        'marital_stat'].apply(lambda x: fun2(x))
    transformed_other['income_50k'] = transformed_other['income_50k'].apply(
        lambda x: fun1(x))
    transformed_other['marital_stat'] = transformed_other[
        'marital_stat'].apply(lambda x: fun2(x))
    # Split the other dataset into 1:1 validation to test according to the paper
    validation_indices = transformed_other.sample(
        frac=0.5, replace=False, random_state=1).index
    test_indices = list(set(transformed_other.index) - set(validation_indices))
    validation_data = transformed_other.iloc[validation_indices]
    test_data = transformed_other.iloc[test_indices]
    # Move the two label columns to the front
    cols = transformed_train.columns.tolist()
    cols.insert(0, cols.pop(cols.index('income_50k')))
    cols.insert(0, cols.pop(cols.index('marital_stat')))
    transformed_train = transformed_train[cols]
    test_data = test_data[cols]
    validation_data = validation_data[cols]
    print(transformed_train.shape, transformed_other.shape,
          validation_data.shape, test_data.shape)
    transformed_train.to_csv(train_data_path + 'train_data.csv', index=False)
    test_data.to_csv(test_data_path + 'test_data.csv', index=False)
    validation_data.to_csv(
        validation_data_path + 'validation_data.csv', index=False)


args = data_preparation_args()
data_preparation(args.train_path, args.test_path, args.train_data_path,
                 args.test_data_path, args.validation_data_path)
```
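After the two `cols.insert(0, ...)` calls, the emitted CSVs carry `marital_stat` and `income_50k` as their first two columns, which is exactly what `utils.py` assumes when it reads the labels from `l[0]` and `l[1]`. A minimal sketch for sanity-checking the output, assuming the script has already been run with the paths from create_data.sh:

```
import pandas as pd

# Hypothetical inspection of the preprocessed output.
df = pd.read_csv('data/data24913/train_data/train_data.csv')
print(df.columns[:2].tolist())  # expected: ['marital_stat', 'income_50k']
print(df.shape[1] - 2)          # feature-column count (499 with this one-hot
                                # encoding, matching a_data in mmoe_train.py)
```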
PaddleRec/multi-task/MMoE/mmoe_train.py (new version)

```
import paddle.fluid as fluid
import pandas as pd
import numpy as np
import paddle
import time
import utils
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import MinMaxScaler
from args import *
import warnings
warnings.filterwarnings("ignore")

# display all columns
pd.set_option('display.max_columns', None)


def set_zero(var_name,
             scope=fluid.global_scope(),
             place=fluid.CPUPlace(),
             param_type="int64"):
    """
    Set tensor of a Variable to zero.
    Args:
        var_name(str): name of Variable
        scope(Scope): Scope object, default is fluid.global_scope()
        place(Place): Place object, default is fluid.CPUPlace()
        param_type(str): param data type, default is int64
    """
    param = scope.var(var_name).get_tensor()
    param_array = np.zeros(param._get_dims()).astype(param_type)
    param.set(param_array, place)


def MMOE(expert_num=8, gate_num=2):
    a_data = fluid.data(name="a", shape=[-1, 499], dtype="float32")
    label_income = fluid.data(
        name="label_income", shape=[-1, 2], dtype="float32", lod_level=0)
    label_marital = fluid.data(
        name="label_marital", shape=[-1, 2], dtype="float32", lod_level=0)

    # f_{i}(x) = activation(W_{i} * x + b), where activation is ReLU according to the paper
    expert_outputs = []
    for i in range(0, expert_num):
        expert_output = fluid.layers.fc(
            input=a_data,
            size=4,
            act='relu',
            bias_attr=fluid.ParamAttr(learning_rate=1.0),
            name='expert_' + str(i))
        expert_outputs.append(expert_output)
    expert_concat = fluid.layers.concat(expert_outputs, axis=1)
    expert_concat = fluid.layers.reshape(expert_concat, [-1, 4, 8])

    # g^{k}(x) = activation(W_{gk} * x + b), where activation is softmax according to the paper
    gate_outputs = []
    for i in range(0, gate_num):
        cur_gate = fluid.layers.fc(
            input=a_data,
            size=8,
            act='softmax',
            bias_attr=fluid.ParamAttr(learning_rate=1.0),
            name='gate_' + str(i))
        gate_outputs.append(cur_gate)

    # f^{k}(x) = sum_{i=1}^{n}(g^{k}(x)_{i} * f_{i}(x))
    final_outputs = []
    for gate_output in gate_outputs:
        expanded_gate_output = fluid.layers.reshape(gate_output, [-1, 1, 8])
        weighted_expert_output = expert_concat * fluid.layers.expand(
            expanded_gate_output, expand_times=[1, 4, 1])
        final_outputs.append(
            fluid.layers.reduce_sum(weighted_expert_output, dim=2))

    # Build tower layer from MMoE layer
    output_layers = []
    for index, task_layer in enumerate(final_outputs):
        tower_layer = fluid.layers.fc(
            input=final_outputs[index],
            size=8,
            act='relu',
            name='task_layer_' + str(index))
        output_layer = fluid.layers.fc(
            input=tower_layer,
            size=2,
            act='softmax',
            name='output_layer_' + str(index))
        output_layers.append(output_layer)

    cost_income = paddle.fluid.layers.cross_entropy(
        input=output_layers[0], label=label_income, soft_label=True)
    cost_marital = paddle.fluid.layers.cross_entropy(
        input=output_layers[1], label=label_marital, soft_label=True)

    label_income_1 = fluid.layers.slice(
        label_income, axes=[1], starts=[1], ends=[2])
    label_marital_1 = fluid.layers.slice(
        label_marital, axes=[1], starts=[1], ends=[2])

    auc_income, batch_auc_1, auc_states_1 = fluid.layers.auc(
        input=output_layers[0],
        label=fluid.layers.cast(x=label_income_1, dtype='int64'))
    auc_marital, batch_auc_2, auc_states_2 = fluid.layers.auc(
        input=output_layers[1],
        label=fluid.layers.cast(x=label_marital_1, dtype='int64'))

    avg_cost_income = fluid.layers.mean(x=cost_income)
    avg_cost_marital = fluid.layers.mean(x=cost_marital)
    cost = avg_cost_income + avg_cost_marital

    return [a_data, label_income, label_marital], cost, output_layers[0], \
        output_layers[1], label_income, label_marital, auc_income, \
        auc_marital, auc_states_1, auc_states_2


args = parse_args()
train_path = args.train_data_path
test_path = args.test_data_path
batch_size = args.batch_size
expert_num = args.expert_num
epochs = args.epochs
gate_num = args.gate_num
print("batch_size:[%d],expert_num:[%d],gate_num[%d]" %
      (batch_size, expert_num, gate_num))

train_reader = utils.prepare_reader(train_path, batch_size)
test_reader = utils.prepare_reader(test_path, batch_size)
data_list, loss, out_1, out_2, label_1, label_2, auc_income, auc_marital, \
    auc_states_1, auc_states_2 = MMOE(expert_num, gate_num)

Adam = fluid.optimizer.AdamOptimizer()
Adam.minimize(loss)
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

test_program = fluid.default_main_program().clone(for_test=True)
loader = fluid.io.DataLoader.from_generator(
    feed_list=data_list, capacity=batch_size, iterable=True)
loader.set_sample_list_generator(train_reader, places=place)
test_loader = fluid.io.DataLoader.from_generator(
    feed_list=data_list, capacity=batch_size, iterable=True)
test_loader.set_sample_list_generator(test_reader, places=place)

mean_auc_income = []
mean_auc_marital = []
inference_scope = fluid.Scope()
for epoch in range(epochs):
    for var in auc_states_1:  # reset auc states
        set_zero(var.name, place=place)
    for var in auc_states_2:  # reset auc states
        set_zero(var.name, place=place)
    begin = time.time()
    auc_1_p = 0.0
    auc_2_p = 0.0
    loss_data = 0.0
    for batch_id, train_data in enumerate(loader()):
        loss_data, out_income, out_marital, label_income, label_marital, \
            auc_1_p, auc_2_p = exe.run(
                feed=train_data,
                fetch_list=[
                    loss.name, out_1, out_2, label_1, label_2, auc_income,
                    auc_marital
                ],
                return_numpy=True)
    for var in auc_states_1:  # reset auc states
        set_zero(var.name, place=place)
    for var in auc_states_2:  # reset auc states
        set_zero(var.name, place=place)
    test_auc_1_p = 0.0
    test_auc_2_p = 0.0
    for batch_id, test_data in enumerate(test_loader()):
        test_out_income, test_out_marital, test_label_income, \
            test_label_marital, test_auc_1_p, test_auc_2_p = exe.run(
                program=test_program,
                feed=test_data,
                fetch_list=[
                    out_1, out_2, label_1, label_2, auc_income, auc_marital
                ],
                return_numpy=True)
    mean_auc_income.append(test_auc_1_p)
    mean_auc_marital.append(test_auc_2_p)
    end = time.time()
    print(
        "epoch_id:[%d],epoch_time:[%.5f s],loss:[%.5f],train_auc_income:[%.5f],"
        "train_auc_marital:[%.5f],test_auc_income:[%.5f],test_auc_marital:[%.5f]"
        % (epoch, end - begin, loss_data, auc_1_p, auc_2_p, test_auc_1_p,
           test_auc_2_p))
print("mean_auc_income:[%.5f],mean_auc_marital[%.5f]" %
      (np.mean(mean_auc_income), np.mean(mean_auc_marital)))
```
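The reshape / expand / reduce_sum sequence inside `MMOE()` is a batched weighted sum of expert outputs. A minimal NumPy sketch of the same shape logic, under the defaults above (8 experts of width 4) and a hypothetical batch of 2:

```
import numpy as np

batch, expert_num, expert_width = 2, 8, 4

# Stands in for expert_concat after the reshape in MMOE(): [batch, width, experts].
expert_concat = np.random.rand(batch, expert_width, expert_num).astype('float32')

# Stands in for one gate's softmax output: [batch, experts], rows sum to 1.
gate = np.random.rand(batch, expert_num).astype('float32')
gate /= gate.sum(axis=1, keepdims=True)

# Reshape the gate to [batch, 1, experts], broadcast over the width axis,
# and sum over experts (broadcasting plays the role of fluid.layers.expand).
mixed = (expert_concat * gate[:, None, :]).sum(axis=2)
print(mixed.shape)  # (2, 4): one width-4 representation per sample for this task's tower
```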
PaddleRec/multi-task/MMoE/train_cpu.sh (new file, mode 100644)

```
# Train on CPU: data paths, batch size, expert/gate counts, and epochs.
python train_mmoe.py --use_gpu False \
                     --train_path data/data24913/train_data/ \
                     --test_path data/data24913/test_data/ \
                     --batch_size 32 \
                     --expert_num 8 \
                     --gate_num 2 \
                     --epochs 400
```
PaddleRec/multi-task/MMoE/train_gpu.sh (new file, mode 100644)

```
# Train on GPU: data paths, batch size, expert/gate counts, and epochs.
python train_mmoe.py --use_gpu True \
                     --train_path data/data24913/train_data/ \
                     --test_path data/data24913/test_data/ \
                     --batch_size 32 \
                     --expert_num 8 \
                     --gate_num 2 \
                     --epochs 400
```
PaddleRec/multi-task/MMoE/utils.py (new file, mode 100644)

```
import random
import pandas as pd
import numpy as np
import os
import paddle.fluid as fluid
import io
from itertools import islice
from sklearn.preprocessing import MinMaxScaler
import warnings


## Read the files line by line
def reader_creator(file_dir):
    def reader():
        files = os.listdir(file_dir)
        for fi in files:
            with io.open(
                    os.path.join(file_dir, fi), "r", encoding='utf-8') as f:
                for l in islice(f, 1, None):  ## skip the header line
                    l = l.strip().split(',')
                    l = list(map(float, l))
                    label_income = []
                    label_marital = []
                    data = l[2:]
                    if int(l[1]) == 0:
                        label_income = [1, 0]
                    elif int(l[1]) == 1:
                        label_income = [0, 1]
                    if int(l[0]) == 0:
                        label_marital = [1, 0]
                    elif int(l[0]) == 1:
                        label_marital = [0, 1]
                    label_income = np.array(label_income)
                    label_marital = np.array(label_marital)
                    yield data, label_income, label_marital

    return reader


## Read one batch
def batch_reader(reader, batch_size):
    def batch_reader():
        r = reader()
        b = []
        for instance in r:
            b.append(instance)
            if len(b) == batch_size:
                yield b
                b = []
        #if len(b) != 0:
        #    yield b

    return batch_reader


## Prepare the data reader
def prepare_reader(data_path, batch_size):
    data_set = reader_creator(data_path)
    #random.shuffle(data_set)
    return batch_reader(data_set, batch_size)
```
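A quick, hypothetical smoke test of this reader, assuming the preprocessed CSVs from create_data.sh are in place (the directory name below is the one used elsewhere in this commit):

```
import utils

reader = utils.prepare_reader('data/data24913/train_data/', batch_size=4)
batch = next(reader())   # one list of batch_size (features, label_income, label_marital) tuples
print(len(batch))        # 4
print(len(batch[0][0]))  # 499 feature values per sample
print(batch[0][1], batch[0][2])  # one-hot income and marital labels, e.g. [1 0] [0 1]
```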
PaddleRec/multi-task/Share_bottom/README.md (new file, mode 100644)

# Share_bottom

Below is a brief directory layout for this example:

```
├── README.md             # documentation
├── share_bottom.py       # share_bottom model script
├── utils.py              # common utility functions
├── args.py               # argument parsing
├── create_data.sh        # script that generates the training data
├── data_preparation.py   # data preprocessing script
├── train_gpu.sh          # GPU training script
├── train_cpu.sh          # CPU training script
```

## Introduction

share_bottom is the basic framework of multi-task learning: the bottom layers' parameters and network structure are shared across all tasks. Its advantage is that it can learn multiple tasks well while greatly reducing the number of parameters. Its drawback is just as clear: because the bottom parameters and structure are fully shared, two weakly related tasks can produce conflicting optimization signals and hurt the final results. Many later neural multi-task models build on share_bottom; MMOE, for example, mitigates the performance loss share_bottom suffers when tasks are weakly correlated.

We implement the share_bottom network structure in PaddlePaddle and validate the model on the open Census-income dataset. This project supports single-machine training on both GPU and CPU.
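In equation form, the contrast with MMOE is that every task shares a single bottom function $f$; only the towers $h^k$ are task-specific:

```latex
y_k = h^k\!\left(f(x)\right), \qquad k = 1, \dots, K
```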
## Data download and preprocessing

Dataset: [Census-income Data](https://archive.ics.uci.edu/ml/datasets/Census-Income+(KDD))

After extracting the data, fill in the file paths in the create_data.sh script and run it:

```
mkdir data/data24913/train_data                         # create the training-data directory
mkdir data/data24913/test_data                          # create the test-data directory
mkdir data/data24913/validation_data                    # create the validation-data directory
train_path="data/data24913/census-income.data"          # raw training-data path
test_path="data/data24913/census-income.test"           # raw test-data path
train_data_path="data/data24913/train_data/"            # preprocessed training-data path
test_data_path="data/data24913/test_data/"              # preprocessed test-data path
validation_data_path="data/data24913/validation_data/"  # preprocessed validation-data path
python data_preparation.py --train_path ${train_path} \
                           --test_path ${test_path} \
                           --train_data_path ${train_data_path} \
                           --test_data_path ${test_data_path} \
                           --validation_data_path ${validation_data_path}
```
## Environment

PaddlePaddle 1.7.0

python3.7

## Single-machine training

GPU environment

Set the data paths and parameters in the train_gpu.sh script:

```
# Train on GPU: data paths, batch size, and epochs.
python share_bottom.py --use_gpu True \
                       --train_path data/data24913/train_data/ \
                       --test_path data/data24913/test_data/ \
                       --batch_size 32 \
                       --epochs 400
```

Make the script executable and run it:

```
./train_gpu.sh
```

CPU environment

Set the data paths and parameters in the train_cpu.sh script:

```
# Train on CPU with the same data paths and hyperparameters.
python share_bottom.py --use_gpu False \
                       --train_path data/data24913/train_data/ \
                       --test_path data/data24913/test_data/ \
                       --batch_size 32 \
                       --epochs 400
```

Make the script executable and run it:

```
./train_cpu.sh
```

## Prediction

Training and prediction alternate in this model; running share_bottom.py also produces the prediction results.

## Results

Training and test results with epoch set to 100:

![image-20200326122512504](C:\Users\overlord\AppData\Roaming\Typora\typora-user-images\image-20200326122512504.png)
PaddleRec/multi-task/Share_bottom/args.py (new file, mode 100644)

```
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import distutils.util


def parse_args():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--epochs", type=int, default=400, help="epochs")
    parser.add_argument("--batch_size", type=int, default=32, help="batch_size")
    parser.add_argument('--use_gpu', type=bool, default=False, help='whether using gpu')
    parser.add_argument('--train_data_path', type=str,
                        default='./data/data24913/train_data/', help="train_data_path")
    parser.add_argument('--test_data_path', type=str,
                        default='./data/data24913/test_data/', help="test_data_path")
    args = parser.parse_args()
    return args


def data_preparation_args():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--train_path", type=str, default='', help="train_path")
    parser.add_argument("--test_path", type=str, default='', help="test_path")
    parser.add_argument('--train_data_path', type=str, default='', help="train_data_path")
    parser.add_argument('--test_data_path', type=str, default='', help="test_data_path")
    parser.add_argument('--validation_data_path', type=str, default='',
                        help="validation_data_path")
    args = parser.parse_args()
    return args
```
PaddleRec/multi-task/Share_bottom/create_data.sh (new file, mode 100644)

```
mkdir data/data24913/train_data
mkdir data/data24913/test_data
mkdir data/data24913/validation_data

train_path="data/data24913/census-income.data"
test_path="data/data24913/census-income.test"
train_data_path="data/data24913/train_data/"
test_data_path="data/data24913/test_data/"
validation_data_path="data/data24913/validation_data/"

python data_preparation.py --train_path ${train_path} \
                           --test_path ${test_path} \
                           --train_data_path ${train_data_path} \
                           --test_data_path ${test_data_path} \
                           --validation_data_path ${validation_data_path}
```
PaddleRec/multi-task/Share_bottom/data_preparation.py (new file, mode 100644; identical in content to the MMoE copy)

```
import pandas as pd
import numpy as np
import paddle.fluid as fluid
from sklearn.preprocessing import MinMaxScaler
from args import *


def fun1(x):
    if x == ' 50000+.':
        return 1
    else:
        return 0


def fun2(x):
    if x == ' Never married':
        return 1
    else:
        return 0


def data_preparation(train_path, test_path, train_data_path, test_data_path,
                     validation_data_path):
    # The column names are from
    # https://www2.1010data.com/documentationcenter/prod/Tutorials/MachineLearningExamples/CensusIncomeDataSet.html
    column_names = [
        'age', 'class_worker', 'det_ind_code', 'det_occ_code', 'education',
        'wage_per_hour', 'hs_college', 'marital_stat', 'major_ind_code',
        'major_occ_code', 'race', 'hisp_origin', 'sex', 'union_member',
        'unemp_reason', 'full_or_part_emp', 'capital_gains', 'capital_losses',
        'stock_dividends', 'tax_filer_stat', 'region_prev_res',
        'state_prev_res', 'det_hh_fam_stat', 'det_hh_summ', 'instance_weight',
        'mig_chg_msa', 'mig_chg_reg', 'mig_move_reg', 'mig_same',
        'mig_prev_sunbelt', 'num_emp', 'fam_under_18', 'country_father',
        'country_mother', 'country_self', 'citizenship', 'own_or_self',
        'vet_question', 'vet_benefits', 'weeks_worked', 'year', 'income_50k'
    ]
    # Load the dataset in Pandas
    train_df = pd.read_csv(
        train_path,
        delimiter=',',
        header=None,
        index_col=None,
        names=column_names)
    other_df = pd.read_csv(
        test_path,
        delimiter=',',
        header=None,
        index_col=None,
        names=column_names)
    # First group of tasks according to the paper
    label_columns = ['income_50k', 'marital_stat']
    # One-hot encoding categorical columns
    categorical_columns = [
        'class_worker', 'det_ind_code', 'det_occ_code', 'education',
        'hs_college', 'major_ind_code', 'major_occ_code', 'race',
        'hisp_origin', 'sex', 'union_member', 'unemp_reason',
        'full_or_part_emp', 'tax_filer_stat', 'region_prev_res',
        'state_prev_res', 'det_hh_fam_stat', 'det_hh_summ', 'mig_chg_msa',
        'mig_chg_reg', 'mig_move_reg', 'mig_same', 'mig_prev_sunbelt',
        'fam_under_18', 'country_father', 'country_mother', 'country_self',
        'citizenship', 'vet_question'
    ]
    train_raw_labels = train_df[label_columns]
    other_raw_labels = other_df[label_columns]
    transformed_train = pd.get_dummies(train_df, columns=categorical_columns)
    transformed_other = pd.get_dummies(other_df, columns=categorical_columns)
    # Filling the missing column in the other set
    transformed_other[
        'det_hh_fam_stat_ Grandchild <18 ever marr not in subfamily'] = 0
    # Binarize the two label columns
    transformed_train['income_50k'] = transformed_train['income_50k'].apply(
        lambda x: fun1(x))
    transformed_train['marital_stat'] = transformed_train[
        'marital_stat'].apply(lambda x: fun2(x))
    transformed_other['income_50k'] = transformed_other['income_50k'].apply(
        lambda x: fun1(x))
    transformed_other['marital_stat'] = transformed_other[
        'marital_stat'].apply(lambda x: fun2(x))
    # Split the other dataset into 1:1 validation to test according to the paper
    validation_indices = transformed_other.sample(
        frac=0.5, replace=False, random_state=1).index
    test_indices = list(set(transformed_other.index) - set(validation_indices))
    validation_data = transformed_other.iloc[validation_indices]
    test_data = transformed_other.iloc[test_indices]
    # Move the two label columns to the front
    cols = transformed_train.columns.tolist()
    cols.insert(0, cols.pop(cols.index('income_50k')))
    cols.insert(0, cols.pop(cols.index('marital_stat')))
    transformed_train = transformed_train[cols]
    test_data = test_data[cols]
    validation_data = validation_data[cols]
    print(transformed_train.shape, transformed_other.shape,
          validation_data.shape, test_data.shape)
    transformed_train.to_csv(train_data_path + 'train_data.csv', index=False)
    test_data.to_csv(test_data_path + 'test_data.csv', index=False)
    validation_data.to_csv(
        validation_data_path + 'validation_data.csv', index=False)


args = data_preparation_args()
data_preparation(args.train_path, args.test_path, args.train_data_path,
                 args.test_data_path, args.validation_data_path)
```
PaddleRec/multi-task/Share_bottom/share_bottom.py (new file, mode 100644)

```
import paddle.fluid as fluid
import pandas as pd
import numpy as np
import paddle
import time
import utils
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import MinMaxScaler
from args import *
import warnings
warnings.filterwarnings("ignore")

# display all columns
pd.set_option('display.max_columns', None)


def set_zero(var_name,
             scope=fluid.global_scope(),
             place=fluid.CPUPlace(),
             param_type="int64"):
    """
    Set tensor of a Variable to zero.
    Args:
        var_name(str): name of Variable
        scope(Scope): Scope object, default is fluid.global_scope()
        place(Place): Place object, default is fluid.CPUPlace()
        param_type(str): param data type, default is int64
    """
    param = scope.var(var_name).get_tensor()
    param_array = np.zeros(param._get_dims()).astype(param_type)
    param.set(param_array, place)


def share_bottom():
    a_data = fluid.data(name="a", shape=[-1, 499], dtype="float32")
    label_income = fluid.data(
        name="label_income", shape=[-1, 2], dtype="float32", lod_level=0)
    label_marital = fluid.data(
        name="label_marital", shape=[-1, 2], dtype="float32", lod_level=0)

    # Pick the bottom width so the weight count matches MMOE:
    # 499*8*4 + 2*(4*8 + 8*2) = 16064
    # 16064 / (499 + 2*(8 + 8*2)) = 29
    bottom_output = fluid.layers.fc(
        input=a_data,
        size=29,
        act='relu',
        bias_attr=fluid.ParamAttr(learning_rate=1.0),
        name='bottom_output')

    # Build tower layer from bottom layer
    output_layers = []
    for index in range(2):
        tower_layer = fluid.layers.fc(
            input=bottom_output,
            size=8,
            act='relu',
            name='task_layer_' + str(index))
        output_layer = fluid.layers.fc(
            input=tower_layer,
            size=2,
            act='softmax',
            name='output_layer_' + str(index))
        output_layers.append(output_layer)

    cost_income = paddle.fluid.layers.cross_entropy(
        input=output_layers[0], label=label_income, soft_label=True)
    cost_marital = paddle.fluid.layers.cross_entropy(
        input=output_layers[1], label=label_marital, soft_label=True)

    label_income_1 = fluid.layers.slice(
        label_income, axes=[1], starts=[1], ends=[2])
    label_marital_1 = fluid.layers.slice(
        label_marital, axes=[1], starts=[1], ends=[2])

    auc_income, batch_auc_1, auc_states_1 = fluid.layers.auc(
        input=output_layers[0],
        label=fluid.layers.cast(x=label_income_1, dtype='int64'))
    auc_marital, batch_auc_2, auc_states_2 = fluid.layers.auc(
        input=output_layers[1],
        label=fluid.layers.cast(x=label_marital_1, dtype='int64'))

    avg_cost_income = fluid.layers.mean(x=cost_income)
    avg_cost_marital = fluid.layers.mean(x=cost_marital)
    cost = avg_cost_income + avg_cost_marital

    return [a_data, label_income, label_marital], cost, output_layers[0], \
        output_layers[1], label_income, label_marital, auc_income, \
        auc_marital, auc_states_1, auc_states_2


args = parse_args()
train_path = args.train_data_path
test_path = args.test_data_path
batch_size = args.batch_size
epochs = args.epochs
print("batch_size:[%d]epochs:[%d]" % (batch_size, epochs))

train_reader = utils.prepare_reader(train_path, batch_size)
test_reader = utils.prepare_reader(test_path, batch_size)
data_list, loss, out_1, out_2, label_1, label_2, auc_income, auc_marital, \
    auc_states_1, auc_states_2 = share_bottom()

Adam = fluid.optimizer.AdamOptimizer()
Adam.minimize(loss)
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

test_program = fluid.default_main_program().clone(for_test=True)
loader = fluid.io.DataLoader.from_generator(
    feed_list=data_list, capacity=batch_size, iterable=True)
loader.set_sample_list_generator(train_reader, places=place)
test_loader = fluid.io.DataLoader.from_generator(
    feed_list=data_list, capacity=batch_size, iterable=True)
test_loader.set_sample_list_generator(test_reader, places=place)

mean_auc_income = []
mean_auc_marital = []
inference_scope = fluid.Scope()
f = open(r'res_share_bottom.txt', 'a+')
for epoch in range(epochs):
    for var in auc_states_1:  # reset auc states
        set_zero(var.name, place=place)
    for var in auc_states_2:  # reset auc states
        set_zero(var.name, place=place)
    begin = time.time()
    auc_1_p = 0.0
    auc_2_p = 0.0
    loss_data = 0.0
    for batch_id, train_data in enumerate(loader()):
        loss_data, out_income, out_marital, label_income, label_marital, \
            auc_1_p, auc_2_p = exe.run(
                feed=train_data,
                fetch_list=[
                    loss.name, out_1, out_2, label_1, label_2, auc_income,
                    auc_marital
                ],
                return_numpy=True)
    for var in auc_states_1:  # reset auc states
        set_zero(var.name, place=place)
    for var in auc_states_2:  # reset auc states
        set_zero(var.name, place=place)
    test_auc_1_p = 0.0
    test_auc_2_p = 0.0
    for batch_id, test_data in enumerate(test_loader()):
        test_out_income, test_out_marital, test_label_income, \
            test_label_marital, test_auc_1_p, test_auc_2_p = exe.run(
                program=test_program,
                feed=test_data,
                fetch_list=[
                    out_1, out_2, label_1, label_2, auc_income, auc_marital
                ],
                return_numpy=True)
    mean_auc_income.append(test_auc_1_p)
    mean_auc_marital.append(test_auc_2_p)
    end = time.time()
    print(
        "epoch_id:[%d],epoch_time:[%.5f s],loss:[%.5f],train_auc_income:[%.5f],"
        "train_auc_marital:[%.5f],test_auc_income:[%.5f],test_auc_marital:[%.5f]"
        % (epoch, end - begin, loss_data, auc_1_p, auc_2_p, test_auc_1_p,
           test_auc_2_p),
        file=f)
print("mean_auc_income:[%.5f],mean_auc_marital[%.5f]" %
      (np.mean(mean_auc_income), np.mean(mean_auc_marital)),
      file=f)
```
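The `size=29` of the shared bottom layer comes from the parameter-count arithmetic in the comment inside `share_bottom()`: the width is picked so that the total weight count roughly matches the MMOE model (biases ignored). Written out:

```latex
% MMOE weights: 8 experts (499 -> 4) plus two towers (4 -> 8 -> 2)
499 \times 4 \times 8 \;+\; 2\,(4 \times 8 + 8 \times 2) = 16064
% shared bottom of width h with two towers (h -> 8 -> 2); solve approximately for h
h \approx \frac{16064}{499 + 2\,(8 + 8 \times 2)} = \frac{16064}{547} \approx 29
```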
PaddleRec/multi-task/Share_bottom/train_cpu.sh (new file, mode 100644)

```
# Train on CPU: data paths, batch size, and epochs.
python share_bottom.py --use_gpu False \
                       --train_path data/data24913/train_data/ \
                       --test_path data/data24913/test_data/ \
                       --batch_size 32 \
                       --epochs 400
```
PaddleRec/multi-task/Share_bottom/train_gpu.sh (new file, mode 100644)

```
# Train on GPU: data paths, batch size, and epochs.
python share_bottom.py --use_gpu True \
                       --train_path data/data24913/train_data/ \
                       --test_path data/data24913/test_data/ \
                       --batch_size 32 \
                       --epochs 400
```
PaddleRec/multi-task/Share_bottom/utils.py (new file, mode 100644)

```
import random
import pandas as pd
import numpy as np
import os
import paddle.fluid as fluid
import io
from itertools import islice
from sklearn.preprocessing import MinMaxScaler
import warnings


## Read the files line by line
def reader_creator(file_dir):
    def reader():
        files = os.listdir(file_dir)
        for fi in files:
            with io.open(
                    os.path.join(file_dir, fi), "r", encoding='utf-8') as f:
                for l in islice(f, 1, None):  ## skip the header line
                    l = l.strip().split(',')
                    l = list(map(float, l))
                    label_income = []
                    label_marital = []
                    data = l[2:]
                    if int(l[1]) == 0:
                        label_income = [1, 0]
                    elif int(l[1]) == 1:
                        label_income = [0, 1]
                    if int(l[0]) == 0:
                        label_marital = [1, 0]
                    elif int(l[0]) == 1:
                        label_marital = [0, 1]
                    label_income = np.array(label_income)
                    label_marital = np.array(label_marital)
                    yield data, label_income, label_marital

    return reader


## Read one batch
def batch_reader(reader, batch_size):
    def batch_reader():
        r = reader()
        b = []
        for instance in r:
            b.append(instance)
            if len(b) == batch_size:
                yield b
                b = []

    return batch_reader


## Prepare the data reader
def prepare_reader(data_path, batch_size):
    data_set = reader_creator(data_path)
    #random.shuffle(data_set)
    return batch_reader(data_set, batch_size)
```