未验证 提交 eadb1ae9 编写于 作者: H hutuxian 提交者: GitHub

Merge pull request #1939 from hutuxian/din_bug_fix

update README and fix print_function
......@@ -36,6 +36,10 @@ DIN通过一个兴趣激活模块(Activation Unit),用预估目标Candidate AD
```
cd data && sh data_process.sh && cd ..
```
如果执行过程中遇到找不到某个包(例如pandas包)的报错,使用如下命令安装对应的包即可。
```
pip install pandas
```
* Step 2: 产生训练集、测试集和config文件
```
......@@ -103,6 +107,12 @@ model saved in din_amazon/global_step_50000
...
```
提示:
* 在单机条件下,使用代码中默认的超参数运行时,产生最优auc的global step大致在440000到500000之间
* 训练超出一定的epoch后会稍稍出现过拟合
## 预测
参考如下命令,开始预测.
......
from __future__ import print_function
import random
import pickle
from __future__ import print_function
random.seed(1234)
......
from __future__ import print_function
import pickle
import pandas as pd
......@@ -13,10 +14,12 @@ def to_df(file_path):
return df
print("start to analyse reviews_Electronics_5.json")
reviews_df = to_df('./raw_data/reviews_Electronics_5.json')
with open('./raw_data/reviews.pkl', 'wb') as f:
pickle.dump(reviews_df, f, pickle.HIGHEST_PROTOCOL)
print("start to analyse meta_Electronics.json")
meta_df = to_df('./raw_data/meta_Electronics.json')
meta_df = meta_df[meta_df['asin'].isin(reviews_df['asin'].unique())]
meta_df = meta_df.reset_index(drop=True)
......
from __future__ import print_function
import random
import pickle
import numpy as np
from __future__ import print_function
random.seed(1234)
......
......@@ -121,7 +121,7 @@ def train():
loss_sum = 0.0
if (global_step > 400000 and global_step % PRINT_STEP == 0) or (
global_step < 400000 and global_step % 50000 == 0):
global_step <= 400000 and global_step % 50000 == 0):
save_dir = args.model_dir + "/global_step_" + str(
global_step)
feed_var_name = [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册