提交 4d73308d 编写于 作者: H hetianjian

gnn: update README and wrap model loading in infer with try/except

上级 ff7c73c5
......@@ -8,8 +8,6 @@
├── train.py # 训练脚本
├── infer.py # 预测脚本
├── network.py # 网络结构
├── cluster_train.py # 多机训练
├── cluster_train.sh # 多机训练脚本
├── reader.py # 和读取数据相关的函数
├── data/
├── download.sh # 下载数据的脚本
......@@ -45,7 +43,7 @@ cd data && sh download.sh
* Step 2: 产生训练集、测试集和config文件
```
python preprocess.py
python preprocess.py --dataset diginetica
cd ..
```
运行之后在data文件夹下会产生diginetica文件夹,里面包含config.txt、test.txt、train.txt三个文件
......
......@@ -5,3 +5,4 @@
unzip dataset-train-diginetica.zip "train-item-views.csv"
sed -i '1d' train-item-views.csv
sed -i '1i session_id;user_id;item_id;timeframe;eventdate' train-item-views.csv
mkdir diginetica
......@@ -50,23 +50,26 @@ def infer(epoch_num):
exe = fluid.Executor(place)
model_path = args.model_path + "epoch_" + str(epoch_num)
[infer_program, feed_names, fetch_targets] = fluid.io.load_inference_model(
model_path, exe)
feeder = fluid.DataFeeder(
feed_list=feed_names, place=place, program=infer_program)
try:
[infer_program, feed_names, fetch_targets] = fluid.io.load_inference_model(
model_path, exe)
feeder = fluid.DataFeeder(
feed_list=feed_names, place=place, program=infer_program)
loss_sum = 0.0
acc_sum = 0.0
count = 0
for data in test_data.reader(batch_size, batch_size, False):
res = exe.run(infer_program,
feed=feeder.feed(data),
fetch_list=fetch_targets)
loss_sum += res[0]
acc_sum += res[1]
count += 1
logger.info("TEST --> loss: %.4lf, Recall@20: %.4lf" %
(loss_sum / count, acc_sum / count))
loss_sum = 0.0
acc_sum = 0.0
count = 0
for data in test_data.reader(batch_size, batch_size, False):
res = exe.run(infer_program,
feed=feeder.feed(data),
fetch_list=fetch_targets)
loss_sum += res[0]
acc_sum += res[1]
count += 1
logger.info("TEST --> loss: %.4lf, Recall@20: %.4lf" %
(loss_sum / count, acc_sum / count))
except ValueError as e:
logger.info("TEST --> error: there is no model in " + model_path)
if __name__ == "__main__":
......
......@@ -23,7 +23,7 @@ class Data():
data = pickle.load(open(path, 'rb'))
self.shuffle = shuffle
self.length = len(data[0])
self.input = zip(data[0], data[1])
self.input = list(zip(data[0], data[1]))
def make_data(self, cur_batch, batch_size):
cur_batch = [list(e) for e in cur_batch]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册