提交 d0067d1f 编写于 作者: H hetianjian

update README and fix print_function

上级 ff7c73c5
...@@ -36,6 +36,10 @@ DIN通过一个兴趣激活模块(Activation Unit),用预估目标Candidate AD ...@@ -36,6 +36,10 @@ DIN通过一个兴趣激活模块(Activation Unit),用预估目标Candidate AD
``` ```
cd data && sh data_process.sh && cd .. cd data && sh data_process.sh && cd ..
``` ```
如果执行过程中遇到找不到某个包(例如pandas包)的报错,使用如下命令安装对应的包即可。
```
pip install pandas
```
* Step 2: 产生训练集、测试集和config文件 * Step 2: 产生训练集、测试集和config文件
``` ```
...@@ -103,6 +107,12 @@ model saved in din_amazon/global_step_50000 ...@@ -103,6 +107,12 @@ model saved in din_amazon/global_step_50000
... ...
``` ```
提示:
* 在单机条件下,使用代码中默认的超参数运行时,产生最优auc的global step大致在440000到500000之间
* 训练超出一定的epoch后会稍稍出现过拟合
## 预测 ## 预测
参考如下命令,开始预测. 参考如下命令,开始预测.
......
from __future__ import print_function
import random import random
import pickle import pickle
from __future__ import print_function
random.seed(1234) random.seed(1234)
......
from __future__ import print_function
import pickle import pickle
import pandas as pd import pandas as pd
...@@ -13,10 +14,12 @@ def to_df(file_path): ...@@ -13,10 +14,12 @@ def to_df(file_path):
return df return df
print("start to analyse reviews_Electronics_5.json")
reviews_df = to_df('./raw_data/reviews_Electronics_5.json') reviews_df = to_df('./raw_data/reviews_Electronics_5.json')
with open('./raw_data/reviews.pkl', 'wb') as f: with open('./raw_data/reviews.pkl', 'wb') as f:
pickle.dump(reviews_df, f, pickle.HIGHEST_PROTOCOL) pickle.dump(reviews_df, f, pickle.HIGHEST_PROTOCOL)
print("start to analyse meta_Electronics.json")
meta_df = to_df('./raw_data/meta_Electronics.json') meta_df = to_df('./raw_data/meta_Electronics.json')
meta_df = meta_df[meta_df['asin'].isin(reviews_df['asin'].unique())] meta_df = meta_df[meta_df['asin'].isin(reviews_df['asin'].unique())]
meta_df = meta_df.reset_index(drop=True) meta_df = meta_df.reset_index(drop=True)
......
from __future__ import print_function
import random import random
import pickle import pickle
import numpy as np import numpy as np
from __future__ import print_function
random.seed(1234) random.seed(1234)
......
...@@ -121,7 +121,7 @@ def train(): ...@@ -121,7 +121,7 @@ def train():
loss_sum = 0.0 loss_sum = 0.0
if (global_step > 400000 and global_step % PRINT_STEP == 0) or ( if (global_step > 400000 and global_step % PRINT_STEP == 0) or (
global_step < 400000 and global_step % 50000 == 0): global_step <= 400000 and global_step % 50000 == 0):
save_dir = args.model_dir + "/global_step_" + str( save_dir = args.model_dir + "/global_step_" + str(
global_step) global_step)
feed_var_name = [ feed_var_name = [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册