未验证 提交 f88e3b80 编写于 作者: C ccmeteorljh 提交者: GitHub

Merge pull request #1940 from hutuxian/gnn_bug_fix

gnn update README and try in infer
...@@ -8,8 +8,6 @@ ...@@ -8,8 +8,6 @@
├── train.py # 训练脚本 ├── train.py # 训练脚本
├── infer.py # 预测脚本 ├── infer.py # 预测脚本
├── network.py # 网络结构 ├── network.py # 网络结构
├── cluster_train.py # 多机训练
├── cluster_train.sh # 多机训练脚本
├── reader.py # 和读取数据相关的函数 ├── reader.py # 和读取数据相关的函数
├── data/ ├── data/
├── download.sh # 下载数据的脚本 ├── download.sh # 下载数据的脚本
...@@ -45,7 +43,7 @@ cd data && sh download.sh ...@@ -45,7 +43,7 @@ cd data && sh download.sh
* Step 2: 产生训练集、测试集和config文件 * Step 2: 产生训练集、测试集和config文件
``` ```
python preprocess.py python preprocess.py --dataset diginetica
cd .. cd ..
``` ```
运行之后在data文件夹下会产生diginetica文件夹,里面包含config.txt、test.txt train.txt三个文件 运行之后在data文件夹下会产生diginetica文件夹,里面包含config.txt、test.txt train.txt三个文件
......
...@@ -5,3 +5,4 @@ ...@@ -5,3 +5,4 @@
unzip dataset-train-diginetica.zip "train-item-views.csv" unzip dataset-train-diginetica.zip "train-item-views.csv"
sed -i '1d' train-item-views.csv sed -i '1d' train-item-views.csv
sed -i '1i session_id;user_id;item_id;timeframe;eventdate' train-item-views.csv sed -i '1i session_id;user_id;item_id;timeframe;eventdate' train-item-views.csv
mkdir diginetica
...@@ -50,6 +50,7 @@ def infer(epoch_num): ...@@ -50,6 +50,7 @@ def infer(epoch_num):
exe = fluid.Executor(place) exe = fluid.Executor(place)
model_path = args.model_path + "epoch_" + str(epoch_num) model_path = args.model_path + "epoch_" + str(epoch_num)
try:
[infer_program, feed_names, fetch_targets] = fluid.io.load_inference_model( [infer_program, feed_names, fetch_targets] = fluid.io.load_inference_model(
model_path, exe) model_path, exe)
feeder = fluid.DataFeeder( feeder = fluid.DataFeeder(
...@@ -67,6 +68,8 @@ def infer(epoch_num): ...@@ -67,6 +68,8 @@ def infer(epoch_num):
count += 1 count += 1
logger.info("TEST --> loss: %.4lf, Recall@20: %.4lf" % logger.info("TEST --> loss: %.4lf, Recall@20: %.4lf" %
(loss_sum / count, acc_sum / count)) (loss_sum / count, acc_sum / count))
except ValueError as e:
logger.info("TEST --> error: there is no model in " + model_path)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -23,7 +23,7 @@ class Data(): ...@@ -23,7 +23,7 @@ class Data():
data = pickle.load(open(path, 'rb')) data = pickle.load(open(path, 'rb'))
self.shuffle = shuffle self.shuffle = shuffle
self.length = len(data[0]) self.length = len(data[0])
self.input = zip(data[0], data[1]) self.input = list(zip(data[0], data[1]))
def make_data(self, cur_batch, batch_size): def make_data(self, cur_batch, batch_size):
cur_batch = [list(e) for e in cur_batch] cur_batch = [list(e) for e in cur_batch]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册