Commit 6fd31ee2: Deep Interest Network (#1781)

* deep interest network

Authored by hutuxian, committed via GitHub. Parent commit: 7f4909c6
# DIN
Below is a brief directory layout for this example, with descriptions:
```text
.
├── README.md            # this documentation
├── train.py             # training script
├── infer.py             # inference script
├── network.py           # network definition
├── cluster_train.py     # distributed (multi-machine) training
├── cluster_train.sh     # distributed training launcher script
├── reader.py            # data-reading utilities
└── data/
    ├── build_dataset.py # converts the text data into Paddle-format train/test data
    ├── convert_pd.py    # converts the raw data into pandas DataFrames
    ├── data_process.sh  # data preprocessing script
    └── remap_id.py      # remaps item / category / reviewer ids to contiguous integers
```
## Introduction
The DIN model is described in the paper [Deep Interest Network for Click-Through Rate Prediction](https://arxiv.org/abs/1706.06978).
DIN uses an activation unit to weight the user's historically clicked items by their relevance to the candidate ad being scored, extracting the part of the user's interest that relates to the current prediction target.
Historical behaviors with high weights represent interest related to the current ad, while low-weight ones are "interest noise" unrelated to it. The activated item representations are multiplied by their activation weights and summed to form the interest representation with respect to the candidate ad.
Finally, this interest representation is concatenated with the user's static features, context features, and the ad-related features, and fed into the subsequent multi-layer DNN, which predicts the probability that the user clicks the candidate ad.
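As a rough illustration, here is a minimal NumPy sketch of this weighted pooling (not the actual fluid implementation in network.py; the scoring function below stands in for the small activation-unit MLP, and normalization is omitted):
```python
import numpy as np

def din_pooling(hist_emb, target_emb, score_fn):
    """hist_emb: [T, H] embeddings of the T historically clicked items;
    target_emb: [H] embedding of the candidate ad;
    score_fn: any function mapping one item's interaction features to a scalar weight."""
    # Interaction features used by the activation unit: [hist, target, hist - target, hist * target]
    target_tile = np.tile(target_emb, (hist_emb.shape[0], 1))
    feats = np.concatenate(
        [hist_emb, target_tile, hist_emb - target_tile, hist_emb * target_tile], axis=1)
    weights = np.array([score_fn(f) for f in feats])   # one activation weight per history item
    return (weights[:, None] * hist_emb).sum(axis=0)   # weighted sum -> interest representation

# Toy usage with a random scorer standing in for the activation-unit MLP.
rng = np.random.RandomState(0)
hist, target = rng.randn(5, 8), rng.randn(8)
print(din_pooling(hist, target, score_fn=lambda f: float(f.sum())).shape)  # (8,)
```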
## Data download and preprocessing
* Step 1: Run the following command to download the [Amazon Product dataset](http://jmcauley.ucsd.edu/data/amazon/) and preprocess it:
```
cd data && sh data_process.sh && cd ..
```
* Step 2: Generate the training set, test set, and config file:
```
python build_dataset.py
```
Afterwards, three files are produced under the data folder: config.txt, paddle_test.txt, and paddle_train.txt.
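For reference, config.txt holds just three integers, one per line, which build_dataset.py writes as the user, item, and category counts (the actual values depend on the downloaded data):
```text
<user_count>
<item_count>
<cate_count>
```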
An example of the data format:
```
3737 19450;288 196;18486;674;1
3647 4342 6855 3805;281 463 558 674;4206;463;1
1805 4309;87 87;21354;556;1
18209 20753;649 241;51924;610;0
13150;351;41455;792;1
35120 40418;157 714;52035;724;0
```
Each line is one sample consisting of 5 fields separated by semicolons. The first two fields are the sequence of historically interacted items and their corresponding categories, the third and fourth fields are the candidate item to be predicted and its category, and the last field is the label indicating whether the item was clicked.
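A minimal sketch of how one such line can be parsed (this mirrors what reader.py does internally; the dictionary keys here are only illustrative):
```python
def parse_sample(line):
    hist_items, hist_cats, target_item, target_cat, label = line.strip().split(';')
    return {
        'hist_item_seq': [int(x) for x in hist_items.split()],
        'hist_cat_seq': [int(x) for x in hist_cats.split()],
        'target_item': int(target_item),
        'target_cat': int(target_cat),
        'label': float(label),
    }

print(parse_sample("3737 19450;288 196;18486;674;1"))
```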
## Training
Run the following command to see the detailed description of all configurable parameters:
```
python train.py -h
```
Single-machine single-GPU training:
``` bash
CUDA_VISIBLE_DEVICES=1 python -u train.py --config_path 'data/config.txt' --train_dir 'data/paddle_train.txt' --batch_size 32 --epoch_num 100 --use_cuda 1 > log.txt 2>&1 &
```
Single-machine CPU training:
``` bash
python -u train.py --config_path 'data/config.txt' --train_dir 'data/paddle_train.txt' --batch_size 32 --epoch_num 100 --use_cuda 0 > log.txt 2>&1 &
```
Note that the single-card training above can be accelerated with the ParallelExecutor by adding the `--parallel 1` flag.
Single-machine multi-GPU training:
``` bash
CUDA_VISIBLE_DEVICES=0,1 python -u train.py --config_path 'data/config.txt' --train_dir 'data/paddle_train.txt' --batch_size 32 --epoch_num 100 --use_cuda 1 --parallel 1 --num_devices 2 > log.txt 2>&1 &
```
Single-machine multi-core CPU training:
``` bash
CPU_NUM=10 python -u train.py --config_path 'data/config.txt' --train_dir 'data/paddle_train.txt' --batch_size 32 --epoch_num 100 --use_cuda 0 --parallel 1 --num_devices 10 > log.txt 2>&1 &
```
## Sample training output
The training log below was produced on a single Tesla K40m GPU (your actual output may differ):
```text
2019-02-22 09:31:51,578 - INFO - reading data begins
2019-02-22 09:32:22,407 - INFO - reading data completes
W0222 09:32:24.151955 7221 device_context.cc:263] Please NOTE: device: 0, CUDA Capability: 35, Driver API Version: 9.0, Runtime API Version: 8.0
W0222 09:32:24.152046 7221 device_context.cc:271] device: 0, cuDNN Version: 7.0.
2019-02-22 09:32:27,797 - INFO - train begins
epoch: 1 global_step: 1000 train_loss: 0.6950 time: 14.64
epoch: 1 global_step: 2000 train_loss: 0.6854 time: 15.41
epoch: 1 global_step: 3000 train_loss: 0.6799 time: 14.84
...
model saved in din_amazon/global_step_50000
...
```
## Inference
Run inference with a command like the following,
where model_path is the path to the saved model and test_path is the path to the test data:
```
CUDA_VISIBLE_DEVICES=3 python infer.py --model_path 'din_amazon/global_step_400000' --test_path 'data/paddle_test.txt' --use_cuda 1
```
## Sample inference output
```text
2019-02-22 11:22:58,804 - INFO - TEST --> loss: [0.47005194] auc:0.863794952818
```
## Distributed training
Refer to cluster_train.py for configuring the multi-machine environment.
Run the following command to simulate the multi-machine scenario locally:
```
sh cluster_train.sh
```
# cluster_train.py: distributed (parameter server) training entry point
import sys
import logging
import time
import argparse
import random

import numpy as np
import paddle
import paddle.fluid as fluid

import network
import reader
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
def parse_args():
parser = argparse.ArgumentParser("din")
parser.add_argument(
'--config_path',
type=str,
default='data/config.txt',
help='dir of config')
parser.add_argument(
'--train_dir',
type=str,
default='data/paddle_train.txt',
help='dir of train file')
parser.add_argument(
'--model_dir',
type=str,
default='din_amazon/',
help='dir of saved model')
parser.add_argument(
'--batch_size', type=int, default=16, help='number of batch size')
parser.add_argument(
'--epoch_num', type=int, default=200, help='number of epoch')
parser.add_argument(
'--use_cuda', type=int, default=0, help='whether to use gpu')
parser.add_argument(
'--parallel',
type=int,
default=0,
help='whether to use parallel executor')
parser.add_argument(
'--base_lr', type=float, default=0.85, help='based learning rate')
parser.add_argument(
'--role', type=str, default='pserver', help='trainer or pserver')
parser.add_argument(
'--endpoints',
type=str,
default='127.0.0.1:6000',
help='The pserver endpoints, like: 127.0.0.1:6000, 127.0.0.1:6001')
parser.add_argument(
'--current_endpoint',
type=str,
default='127.0.0.1:6000',
help='The current_endpoint')
parser.add_argument(
'--trainer_id',
type=int,
default=0,
        help='trainer id; only trainer_id=0 saves the model')
parser.add_argument(
'--trainers',
type=int,
default=1,
        help='The number of trainers (default: 1)')
args = parser.parse_args()
return args
def train():
args = parse_args()
config_path = args.config_path
train_path = args.train_dir
epoch_num = args.epoch_num
use_cuda = True if args.use_cuda else False
use_parallel = True if args.parallel else False
logger.info("reading data begins")
user_count, item_count, cat_count = reader.config_read(config_path)
#data_reader, max_len = reader.prepare_reader(train_path, args.batch_size)
logger.info("reading data completes")
    avg_cost, pred = network.network(item_count, cat_count, 433)  # max_len fixed to 433 here; the reader (which computes max_len) is built later inside train_loop
#fluid.clip.set_gradient_clip(clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=5.0))
base_lr = args.base_lr
boundaries = [410000]
values = [base_lr, 0.2]
sgd_optimizer = fluid.optimizer.SGD(
learning_rate=fluid.layers.piecewise_decay(
boundaries=boundaries, values=values))
sgd_optimizer.minimize(avg_cost)
def train_loop(main_program):
data_reader, max_len = reader.prepare_reader(train_path,
args.batch_size)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
feeder = fluid.DataFeeder(
feed_list=[
"hist_item_seq", "hist_cat_seq", "target_item", "target_cat",
"label", "mask", "target_item_seq", "target_cat_seq"
],
place=place)
if use_parallel:
train_exe = fluid.ParallelExecutor(
use_cuda=use_cuda,
loss_name=avg_cost.name,
main_program=main_program)
else:
train_exe = exe
logger.info("train begins")
global_step = 0
PRINT_STEP = 1000
start_time = time.time()
loss_sum = 0.0
for id in range(epoch_num):
epoch = id + 1
for data in data_reader():
global_step += 1
results = train_exe.run(main_program,
feed=feeder.feed(data),
fetch_list=[avg_cost.name, pred.name],
return_numpy=True)
loss_sum += results[0].mean()
if global_step % PRINT_STEP == 0:
logger.info(
"epoch: %d\tglobal_step: %d\ttrain_loss: %.4f\t\ttime: %.2f"
% (epoch, global_step, loss_sum / PRINT_STEP,
time.time() - start_time))
start_time = time.time()
loss_sum = 0.0
if (global_step > 400000 and
global_step % PRINT_STEP == 0) or (
global_step < 400000 and
global_step % 50000 == 0):
save_dir = args.model_dir + "/global_step_" + str(
global_step)
feed_var_name = [
"hist_item_seq", "hist_cat_seq", "target_item",
"target_cat", "label", "mask", "target_item_seq",
"target_cat_seq"
]
fetch_vars = [avg_cost, pred]
fluid.io.save_inference_model(save_dir, feed_var_name,
fetch_vars, exe)
train_exe.close()
t = fluid.DistributeTranspiler()
t.transpile(
args.trainer_id, pservers=args.endpoints, trainers=args.trainers)
if args.role == "pserver":
logger.info("run psever")
prog, startup = t.get_pserver_programs(args.current_endpoint)
exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup)
exe.run(prog)
elif args.role == "trainer":
logger.info("run trainer")
train_loop(t.get_trainer_program())
if __name__ == "__main__":
train()
#!/bin/bash
# cluster_train.sh: launch two pservers and two trainers locally to simulate distributed training
#export GLOG_v=30
#export GLOG_logtostderr=1
# parameter server 0
python -u cluster_train.py \
--config_path 'data/config.txt' \
--train_dir 'data/paddle_train.txt' \
--batch_size 32 \
--epoch_num 100 \
--use_cuda 0 \
--parallel 0 \
--role pserver \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--current_endpoint 127.0.0.1:6000 \
--trainers 2 \
> pserver0.log 2>&1 &
# parameter server 1
python -u cluster_train.py \
--config_path 'data/config.txt' \
--train_dir 'data/paddle_train.txt' \
--batch_size 32 \
--epoch_num 100 \
--use_cuda 0 \
--parallel 0 \
--role pserver \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--current_endpoint 127.0.0.1:6001 \
--trainers 2 \
> pserver1.log 2>&1 &
# trainer 0
python -u cluster_train.py \
--config_path 'data/config.txt' \
--train_dir 'data/paddle_train.txt' \
--batch_size 32 \
--epoch_num 100 \
--use_cuda 0 \
--parallel 0 \
--role trainer \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--trainers 2 \
--trainer_id 0 \
> trainer0.log 2>&1 &
# trainer 1
python -u cluster_train.py \
--config_path 'data/config.txt' \
--train_dir 'data/paddle_train.txt' \
--batch_size 32 \
--epoch_num 100 \
--use_cuda 0 \
--parallel 0 \
--role trainer \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--trainers 2 \
--trainer_id 1 \
> trainer1.log 2>&1 &
# build_dataset.py: build paddle_train.txt / paddle_test.txt / config.txt from remap.pkl
from __future__ import print_function

import random
import pickle

random.seed(1234)

print("read and process data")
with open('./raw_data/remap.pkl', 'rb') as f:
reviews_df = pickle.load(f)
cate_list = pickle.load(f)
user_count, item_count, cate_count, example_count = pickle.load(f)
train_set = []
test_set = []
for reviewerID, hist in reviews_df.groupby('reviewerID'):
pos_list = hist['asin'].tolist()
def gen_neg():
neg = pos_list[0]
while neg in pos_list:
neg = random.randint(0, item_count - 1)
return neg
neg_list = [gen_neg() for i in range(len(pos_list))]
for i in range(1, len(pos_list)):
hist = pos_list[:i]
if i != len(pos_list) - 1:
train_set.append((reviewerID, hist, pos_list[i], 1))
train_set.append((reviewerID, hist, neg_list[i], 0))
else:
label = (pos_list[i], neg_list[i])
test_set.append((reviewerID, hist, label))
random.shuffle(train_set)
random.shuffle(test_set)
assert len(test_set) == user_count
def print_to_file(data, fout):
for i in range(len(data)):
fout.write(str(data[i]))
if i != len(data) - 1:
fout.write(' ')
else:
fout.write(';')
print("make train data")
with open("paddle_train.txt", "w") as fout:
for line in train_set:
history = line[1]
target = line[2]
label = line[3]
cate = [cate_list[x] for x in history]
print_to_file(history, fout)
print_to_file(cate, fout)
fout.write(str(target) + ";")
fout.write(str(cate_list[target]) + ";")
fout.write(str(label) + "\n")
print("make test data")
with open("paddle_test.txt", "w") as fout:
for line in test_set:
history = line[1]
target = line[2]
cate = [cate_list[x] for x in history]
print_to_file(history, fout)
print_to_file(cate, fout)
fout.write(str(target[0]) + ";")
fout.write(str(cate_list[target[0]]) + ";")
fout.write("1\n")
print_to_file(history, fout)
print_to_file(cate, fout)
fout.write(str(target[1]) + ";")
fout.write(str(cate_list[target[1]]) + ";")
fout.write("0\n")
print("make config data")
with open('config.txt', 'w') as f:
f.write(str(user_count) + "\n")
f.write(str(item_count) + "\n")
f.write(str(cate_count) + "\n")
# convert_pd.py: load the raw Amazon reviews / metadata JSON into pandas DataFrames and pickle them
import pickle
import pandas as pd
def to_df(file_path):
with open(file_path, 'r') as fin:
df = {}
i = 0
for line in fin:
df[i] = eval(line)
i += 1
df = pd.DataFrame.from_dict(df, orient='index')
return df
reviews_df = to_df('./raw_data/reviews_Electronics_5.json')
with open('./raw_data/reviews.pkl', 'wb') as f:
pickle.dump(reviews_df, f, pickle.HIGHEST_PROTOCOL)
meta_df = to_df('./raw_data/meta_Electronics.json')
meta_df = meta_df[meta_df['asin'].isin(reviews_df['asin'].unique())]
meta_df = meta_df.reset_index(drop=True)
with open('./raw_data/meta.pkl', 'wb') as f:
pickle.dump(meta_df, f, pickle.HIGHEST_PROTOCOL)
#!/bin/bash
# data_process.sh: download the Amazon Electronics dataset and run the preprocessing scripts
echo "begin download data"
mkdir raw_data && cd raw_data
wget -c http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/reviews_Electronics_5.json.gz
gzip -d reviews_Electronics_5.json.gz
wget -c http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/meta_Electronics.json.gz
gzip -d meta_Electronics.json.gz
echo "download data successful"
cd ..
python convert_pd.py
python remap_id.py
# remap_id.py: remap item / category / reviewer ids to contiguous integer ids
from __future__ import print_function

import random
import pickle

import numpy as np

random.seed(1234)
with open('./raw_data/reviews.pkl', 'rb') as f:
reviews_df = pickle.load(f)
reviews_df = reviews_df[['reviewerID', 'asin', 'unixReviewTime']]
with open('./raw_data/meta.pkl', 'rb') as f:
meta_df = pickle.load(f)
meta_df = meta_df[['asin', 'categories']]
meta_df['categories'] = meta_df['categories'].map(lambda x: x[-1][-1])
def build_map(df, col_name):
key = sorted(df[col_name].unique().tolist())
m = dict(zip(key, range(len(key))))
df[col_name] = df[col_name].map(lambda x: m[x])
return m, key
asin_map, asin_key = build_map(meta_df, 'asin')
cate_map, cate_key = build_map(meta_df, 'categories')
revi_map, revi_key = build_map(reviews_df, 'reviewerID')
user_count, item_count, cate_count, example_count =\
len(revi_map), len(asin_map), len(cate_map), reviews_df.shape[0]
print('user_count: %d\titem_count: %d\tcate_count: %d\texample_count: %d' %
(user_count, item_count, cate_count, example_count))
meta_df = meta_df.sort_values('asin')
meta_df = meta_df.reset_index(drop=True)
reviews_df['asin'] = reviews_df['asin'].map(lambda x: asin_map[x])
reviews_df = reviews_df.sort_values(['reviewerID', 'unixReviewTime'])
reviews_df = reviews_df.reset_index(drop=True)
reviews_df = reviews_df[['reviewerID', 'asin', 'unixReviewTime']]
cate_list = [meta_df['categories'][i] for i in range(len(asin_map))]
cate_list = np.array(cate_list, dtype=np.int32)
with open('./raw_data/remap.pkl', 'wb') as f:
pickle.dump(reviews_df, f, pickle.HIGHEST_PROTOCOL) # uid, iid
pickle.dump(cate_list, f, pickle.HIGHEST_PROTOCOL) # cid of iid line
pickle.dump((user_count, item_count, cate_count, example_count), f,
pickle.HIGHEST_PROTOCOL)
pickle.dump((asin_key, cate_key, revi_key), f, pickle.HIGHEST_PROTOCOL)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
# infer.py: load a saved inference model and evaluate test loss and AUC
import argparse
import logging
import numpy as np
import os
import paddle
import paddle.fluid as fluid
import reader
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
def parse_args():
parser = argparse.ArgumentParser(description="PaddlePaddle DIN example")
parser.add_argument(
'--model_path', type=str, required=True, help="path of model parameters")
parser.add_argument(
'--test_path', type=str, default='data/paddle_test.txt.bak', help='dir of test file')
parser.add_argument(
'--use_cuda', type=int, default=0, help='whether to use gpu')
return parser.parse_args()
def calc_auc(raw_arr):
# sort by pred value, from small to big
arr = sorted(raw_arr, key=lambda d: d[2])
auc = 0.0
fp1, tp1, fp2, tp2 = 0.0, 0.0, 0.0, 0.0
for record in arr:
fp2 += record[0] # noclick
tp2 += record[1] # click
auc += (fp2 - fp1) * (tp2 + tp1)
fp1, tp1 = fp2, tp2
    # if all non-click or all click, discard
threshold = len(arr) - 1e-3
if tp2 > threshold or fp2 > threshold:
return -0.5
if tp2 * fp2 > 0.0: # normal auc
return (1.0 - auc / (2.0 * tp2 * fp2))
else:
return None
def infer():
args = parse_args()
model_path = args.model_path
use_cuda = True if args.use_cuda else False
data_reader, _ = reader.prepare_reader(args.test_path, 32 * 16)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
inference_scope = fluid.core.Scope()
exe = fluid.Executor(place)
#with fluid.scope_guard(inference_scope):
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(model_path, exe)
feeder = fluid.DataFeeder(
feed_list=feed_target_names, place=place, program=inference_program)
loss_sum = 0.0
score = []
count = 0
for data in data_reader():
res = exe.run(inference_program,
feed=feeder.feed(data),
fetch_list=fetch_targets)
loss_sum += res[0]
for i in range(len(data)):
if data[i][4] > 0.5:
score.append([0, 1, res[1][i]])
else:
score.append([1, 0, res[1][i]])
count += 1
auc = calc_auc(score)
logger.info("TEST --> loss: {}, auc: {}".format(loss_sum / count, auc))
if __name__ == '__main__':
infer()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
# network.py: DIN network definition
import paddle.fluid as fluid
def din_attention(hist, target_expand, max_len, mask):
"""activation weight"""
hidden_size = hist.shape[-1]
concat = fluid.layers.concat(
[hist, target_expand, hist - target_expand, hist * target_expand],
axis=2)
atten_fc1 = fluid.layers.fc(name="atten_fc1",
input=concat,
size=80,
act="sigmoid",
num_flatten_dims=2)
atten_fc2 = fluid.layers.fc(name="atten_fc2",
input=atten_fc1,
size=40,
act="sigmoid",
num_flatten_dims=2)
atten_fc3 = fluid.layers.fc(name="atten_fc3",
input=atten_fc2,
size=1,
num_flatten_dims=2)
atten_fc3 += mask
atten_fc3 = fluid.layers.transpose(x=atten_fc3, perm=[0, 2, 1])
atten_fc3 = fluid.layers.scale(x=atten_fc3, scale=hidden_size**-0.5)
weight = fluid.layers.softmax(atten_fc3)
out = fluid.layers.matmul(weight, hist)
out = fluid.layers.reshape(x=out, shape=[0, hidden_size])
return out
def network(item_count, cat_count, max_len):
"""network definition"""
item_emb_size = 64
cat_emb_size = 64
is_sparse = False
#significant for speeding up the training process
item_emb_attr = fluid.ParamAttr(name="item_emb")
cat_emb_attr = fluid.ParamAttr(name="cat_emb")
hist_item_seq = fluid.layers.data(
name="hist_item_seq", shape=[max_len, 1], dtype="int64")
hist_cat_seq = fluid.layers.data(
name="hist_cat_seq", shape=[max_len, 1], dtype="int64")
target_item = fluid.layers.data(
name="target_item", shape=[1], dtype="int64")
target_cat = fluid.layers.data(
name="target_cat", shape=[1], dtype="int64")
label = fluid.layers.data(
name="label", shape=[1], dtype="float32")
mask = fluid.layers.data(
name="mask", shape=[max_len, 1], dtype="float32")
target_item_seq = fluid.layers.data(
name="target_item_seq", shape=[max_len, 1], dtype="int64")
target_cat_seq = fluid.layers.data(
name="target_cat_seq", shape=[max_len, 1], dtype="int64", lod_level=0)
hist_item_emb = fluid.layers.embedding(
input=hist_item_seq,
size=[item_count, item_emb_size],
param_attr=item_emb_attr,
is_sparse=is_sparse)
hist_cat_emb = fluid.layers.embedding(
input=hist_cat_seq,
size=[cat_count, cat_emb_size],
param_attr=cat_emb_attr,
is_sparse=is_sparse)
target_item_emb = fluid.layers.embedding(
input=target_item,
size=[item_count, item_emb_size],
param_attr=item_emb_attr,
is_sparse=is_sparse)
target_cat_emb = fluid.layers.embedding(
input=target_cat,
size=[cat_count, cat_emb_size],
param_attr=cat_emb_attr,
is_sparse=is_sparse)
target_item_seq_emb = fluid.layers.embedding(
input=target_item_seq,
size=[item_count, item_emb_size],
param_attr=item_emb_attr,
is_sparse=is_sparse)
target_cat_seq_emb = fluid.layers.embedding(
input=target_cat_seq,
size=[cat_count, cat_emb_size],
param_attr=cat_emb_attr,
is_sparse=is_sparse)
item_b = fluid.layers.embedding(
input=target_item,
size=[item_count, 1],
param_attr=fluid.initializer.Constant(value=0.0))
hist_seq_concat = fluid.layers.concat([hist_item_emb, hist_cat_emb], axis=2)
target_seq_concat = fluid.layers.concat(
[target_item_seq_emb, target_cat_seq_emb], axis=2)
target_concat = fluid.layers.concat(
[target_item_emb, target_cat_emb], axis=1)
out = din_attention(hist_seq_concat, target_seq_concat, max_len, mask)
out_fc = fluid.layers.fc(name="out_fc",
input=out,
size=item_emb_size + cat_emb_size,
num_flatten_dims=1)
embedding_concat = fluid.layers.concat([out_fc, target_concat], axis=1)
fc1 = fluid.layers.fc(name="fc1",
input=embedding_concat,
size=80,
act="sigmoid")
fc2 = fluid.layers.fc(name="fc2", input=fc1, size=40, act="sigmoid")
fc3 = fluid.layers.fc(name="fc3", input=fc2, size=1)
logit = fc3 + item_b
loss = fluid.layers.sigmoid_cross_entropy_with_logits(x=logit, label=label)
avg_loss = fluid.layers.mean(loss)
return avg_loss, fluid.layers.sigmoid(logit)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
# reader.py: data reading and batching utilities
import os
import random
import numpy as np
import paddle
import pickle
def pad_batch_data(input, max_len):
res = np.array([x + [0] * (max_len - len(x)) for x in input])
res = res.astype("int64").reshape([-1, max_len, 1])
return res
def make_data(b):
max_len = max(len(x[0]) for x in b)
item = pad_batch_data([x[0] for x in b], max_len)
cat = pad_batch_data([x[1] for x in b], max_len)
len_array = [len(x[0]) for x in b]
mask = np.array(
[[0] * x + [-1e9] * (max_len - x) for x in len_array]).reshape(
[-1, max_len, 1])
target_item_seq = np.array(
[[x[2]] * max_len for x in b]).astype("int64").reshape(
[-1, max_len, 1])
target_cat_seq = np.array(
[[x[3]] * max_len for x in b]).astype("int64").reshape(
[-1, max_len, 1])
res = []
for i in range(len(b)):
res.append([
item[i], cat[i], b[i][2], b[i][3], b[i][4], mask[i],
target_item_seq[i], target_cat_seq[i]
])
return res
def batch_reader(reader, batch_size, group_size):
def batch_reader():
bg = []
for line in reader:
bg.append(line)
if len(bg) == group_size:
sortb = sorted(bg, key=lambda x: len(x[0]), reverse=False)
bg = []
for i in range(0, group_size, batch_size):
b = sortb[i:i + batch_size]
yield make_data(b)
len_bg = len(bg)
if len_bg != 0:
sortb = sorted(bg, key=lambda x: len(x[0]), reverse=False)
bg = []
remain = len_bg % batch_size
for i in range(0, len_bg - remain, batch_size):
b = sortb[i:i + batch_size]
yield make_data(b)
return batch_reader
def base_read(file_dir):
res = []
max_len = 0
with open(file_dir, "r") as fin:
for line in fin:
line = line.strip().split(';')
hist = line[0].split()
cate = line[1].split()
max_len = max(max_len, len(hist))
res.append([hist, cate, line[2], line[3], float(line[4])])
return res, max_len
def prepare_reader(data_path, bs):
data_set, max_len = base_read(data_path)
random.shuffle(data_set)
return batch_reader(data_set, bs, bs * 20), max_len
def config_read(config_path):
with open(config_path, "r") as fin:
user_count = int(fin.readline().strip())
item_count = int(fin.readline().strip())
cat_count = int(fin.readline().strip())
return user_count, item_count, cat_count
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
# train.py: single-machine training entry point
import sys
import logging
import time
import argparse
import random

import numpy as np
import paddle
import paddle.fluid as fluid

import network
import reader
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
def parse_args():
parser = argparse.ArgumentParser("din")
parser.add_argument(
'--config_path', type=str, default='data/config.txt', help='dir of config')
parser.add_argument(
'--train_dir', type=str, default='data/paddle_train.txt', help='dir of train file')
parser.add_argument(
'--model_dir', type=str, default='din_amazon', help='dir of saved model')
parser.add_argument(
'--batch_size', type=int, default=16, help='number of batch size')
parser.add_argument(
'--epoch_num', type=int, default=200, help='number of epoch')
parser.add_argument(
'--use_cuda', type=int, default=0, help='whether to use gpu')
parser.add_argument(
'--parallel', type=int, default=0, help='whether to use parallel executor')
parser.add_argument(
'--base_lr', type=float, default=0.85, help='based learning rate')
parser.add_argument(
'--num_devices', type=int, default=1, help='Number of GPU devices')
args = parser.parse_args()
return args
def train():
args = parse_args()
config_path = args.config_path
train_path = args.train_dir
epoch_num = args.epoch_num
use_cuda = True if args.use_cuda else False
use_parallel = True if args.parallel else False
logger.info("reading data begins")
user_count, item_count, cat_count = reader.config_read(config_path)
data_reader, max_len = reader.prepare_reader(train_path, args.batch_size *
args.num_devices)
logger.info("reading data completes")
avg_cost, pred = network.network(item_count, cat_count, max_len)
fluid.clip.set_gradient_clip(clip=fluid.clip.GradientClipByGlobalNorm(
clip_norm=5.0))
base_lr = args.base_lr
boundaries = [410000]
values = [base_lr, 0.2]
sgd_optimizer = fluid.optimizer.SGD(
learning_rate=fluid.layers.piecewise_decay(
boundaries=boundaries, values=values))
sgd_optimizer.minimize(avg_cost)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
feeder = fluid.DataFeeder(
feed_list=[
"hist_item_seq", "hist_cat_seq", "target_item", "target_cat",
"label", "mask", "target_item_seq", "target_cat_seq"
],
place=place)
if use_parallel:
train_exe = fluid.ParallelExecutor(
use_cuda=use_cuda, loss_name=avg_cost.name)
else:
train_exe = exe
logger.info("train begins")
global_step = 0
PRINT_STEP = 1000
start_time = time.time()
loss_sum = 0.0
for id in range(epoch_num):
epoch = id + 1
for data in data_reader():
global_step += 1
results = train_exe.run(feed=feeder.feed(data),
fetch_list=[avg_cost.name, pred.name],
return_numpy=True)
loss_sum += results[0].mean()
if global_step % PRINT_STEP == 0:
logger.info(
"epoch: %d\tglobal_step: %d\ttrain_loss: %.4f\t\ttime: %.2f"
% (epoch, global_step, loss_sum / PRINT_STEP,
time.time() - start_time))
start_time = time.time()
loss_sum = 0.0
if (global_step > 400000 and global_step % PRINT_STEP == 0) or (
global_step < 400000 and global_step % 50000 == 0):
save_dir = args.model_dir + "/global_step_" + str(
global_step)
feed_var_name = [
"hist_item_seq", "hist_cat_seq", "target_item",
"target_cat", "label", "mask", "target_item_seq",
"target_cat_seq"
]
fetch_vars = [avg_cost, pred]
fluid.io.save_inference_model(save_dir, feed_var_name,
fetch_vars, exe)
logger.info("model saved in " + save_dir)
if __name__ == "__main__":
train()