未验证 提交 431fb9d6 编写于 作者: H hutuxian 提交者: GitHub

fix some path/download problems (#3643)

上级 541444b1
......@@ -41,6 +41,9 @@ cd data && sh data_process.sh && cd ..
pip install pandas
```
**Windows系统下请用户自行下载数据进行解压,下载链接为:[reviews_Electronics](http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/reviews_Electronics_5.json.gz)和[meta_Electronics](http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/meta_Electronics.json.gz)。**
* Step 2: 产生训练集、测试集和config文件
```
python build_dataset.py
......
......@@ -136,8 +136,8 @@ def train():
if (global_step > 400000 and global_step % PRINT_STEP == 0) or (
global_step <= 400000 and global_step % 50000 == 0):
save_dir = args.model_dir + "/global_step_" + str(
global_step)
save_dir = os.path.join(args.model_dir, "/global_step_" + str(
global_step))
feed_var_name = [
"hist_item_seq", "hist_cat_seq", "target_item",
"target_cat", "label", "mask", "target_item_seq",
......
......@@ -40,11 +40,12 @@ SR-GNN模型的介绍可以参阅论文[Session-based Recommendation with Graph
* Step 1: 运行如下命令,下载DIGINETICA数据集并进行预处理
```
cd data && sh download.sh
cd data && python download.py
```
* Step 2: 产生训练集、测试集和config文件
```
mkdir diginetica
python preprocess.py --dataset diginetica
cd ..
```
......
import requests
import sys
import time
import os
lasttime = time.time()
FLUSH_INTERVAL = 0.1
def progress(str, end=False):
global lasttime
if end:
str += "\n"
lasttime = 0
if time.time() - lasttime >= FLUSH_INTERVAL:
sys.stdout.write("\r%s" % str)
lasttime = time.time()
sys.stdout.flush()
def _download_file(url, savepath, print_progress):
r = requests.get(url, stream=True)
total_length = r.headers.get('content-length')
if total_length is None:
with open(savepath, 'wb') as f:
shutil.copyfileobj(r.raw, f)
else:
with open(savepath, 'wb') as f:
dl = 0
total_length = int(total_length)
starttime = time.time()
if print_progress:
print("Downloading %s" % os.path.basename(savepath))
for data in r.iter_content(chunk_size=4096):
dl += len(data)
f.write(data)
if print_progress:
done = int(50 * dl / total_length)
progress("[%-50s] %.2f%%" %
('=' * done, float(100 * dl) / total_length))
if print_progress:
progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)
_download_file("https://sr-gnn.bj.bcebos.com/train-item-views.csv",
"./train-item-views.csv", True)
#!/bin/bash
wget --no-check-certificate https://sr-gnn.bj.bcebos.com/train-item-views.csv
mkdir diginetica
......@@ -60,7 +60,7 @@ def infer(args):
infer_program = fluid.default_main_program().clone(for_test=True)
for epoch_num in range(args.start_index, args.last_index + 1):
model_path = args.model_path + "epoch_" + str(epoch_num)
model_path = os.path.join(args.model_path, "epoch_" + str(epoch_num))
try:
if not os.path.exists(model_path):
raise ValueError()
......
......@@ -139,7 +139,7 @@ def train():
except fluid.core.EOFException:
py_reader.reset()
logger.info("epoch loss: %.4lf" % (np.mean(epoch_sum)))
save_dir = args.model_path + "/epoch_" + str(i)
save_dir = os.path.join(args.model_path, "epoch_" + str(i))
fetch_vars = [loss, acc]
fluid.io.save_inference_model(save_dir, feed_list, fetch_vars, exe)
logger.info("model saved in " + save_dir)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册