未验证 提交 b8cbdd33 编写于 作者: F frankwhzhang 提交者: GitHub

Merge pull request #1384 from frankwhzhang/gru4rec

add gru4rec model
...@@ -74,14 +74,16 @@ python convert_format.py ...@@ -74,14 +74,16 @@ python convert_format.py
``` ```
## 训练 ## 训练
GPU 环境 默认配置 '--use_cuda 1' 表示使用gpu, 缺省表示使用cpu '--parallel 1' 表示使用多卡,缺省表示使用单卡
运行命令 `CUDA_VISIBLE_DEVICES=0 python train.py train_file test_file` 开始训练模型。
```python GPU 环境
CUDA_VISIBLE_DEVICES=0 python train.py small_train.txt small_test.file 运行命令 `CUDA_VISIBLE_DEVICES=0 python train.py train_file test_file --use_cuda 1` 开始训练模型。
```
CUDA_VISIBLE_DEVICES=0 python train.py small_train.txt small_test.txt --use_cuda 1
``` ```
CPU 环境 CPU 环境
运行命令 `python train.py train_file test_file` 开始训练模型。 运行命令 `python train.py train_file test_file` 开始训练模型。
```python ```
python train.py small_train.txt small_test.txt python train.py small_train.txt small_test.txt
``` ```
...@@ -100,8 +102,8 @@ python train.py small_train.txt small_test.txt ...@@ -100,8 +102,8 @@ python train.py small_train.txt small_test.txt
base_lr=0.01, # base learning rate base_lr=0.01, # base learning rate
batch_size=batch_size, batch_size=batch_size,
pass_num=10, # the number of passed for training pass_num=10, # the number of passed for training
use_cuda=True, # whether to use GPU card use_cuda=use_cuda, # whether to use GPU card
parallel=False, # whether to be parallel parallel=parallel, # whether to be parallel
model_dir="model_recall20", # directory to save model model_dir="model_recall20", # directory to save model
init_low_bound=-0.1, # uniform parameter initialization lower bound init_low_bound=-0.1, # uniform parameter initialization lower bound
init_high_bound=0.1) # uniform parameter initialization upper bound init_high_bound=0.1) # uniform parameter initialization upper bound
...@@ -198,9 +200,9 @@ model saved in model_recall20/epoch_1 ...@@ -198,9 +200,9 @@ model saved in model_recall20/epoch_1
``` ```
## 预测 ## 预测
运行命令 `CUDA_VISIBLE_DEVICES=0 python infer.py model_dir start_epoch last_epoch(inclusive) train_file test_file` 开始预测其中,start_epoch指定开始预测的轮次,last_epoch指定结束的轮次,例如 运行命令 `CUDA_VISIBLE_DEVICES=0 python infer.py model_dir start_epoch last_epoch(inclusive) train_file test_file` 开始预测.其中,start_epoch指定开始预测的轮次,last_epoch指定结束的轮次,例如
```python ```python
CUDA_VISIBLE_DEVICES=0 python infer.py model 1 10 small_train.txt small_test.txt# prediction from epoch 1 to epoch 10 small_train.txt small_test.txt CUDA_VISIBLE_DEVICES=0 python infer.py model 1 10 small_train.txt small_test.txt
``` ```
## 预测结果示例 ## 预测结果示例
......
...@@ -17,7 +17,8 @@ def parse_args(): ...@@ -17,7 +17,8 @@ def parse_args():
parser = argparse.ArgumentParser("gru4rec benchmark.") parser = argparse.ArgumentParser("gru4rec benchmark.")
parser.add_argument('train_file') parser.add_argument('train_file')
parser.add_argument('test_file') parser.add_argument('test_file')
parser.add_argument('--use_cuda', help='whether use gpu')
parser.add_argument('--parallel', help='whether parallel')
parser.add_argument( parser.add_argument(
'--enable_ce', '--enable_ce',
action='store_true', action='store_true',
...@@ -182,6 +183,9 @@ def train_net(): ...@@ -182,6 +183,9 @@ def train_net():
args = parse_args() args = parse_args()
train_file = args.train_file train_file = args.train_file
test_file = args.test_file test_file = args.test_file
use_cuda = True if args.use_cuda else False
parallel = True if args.parallel else False
print("use_cuda:", use_cuda, "parallel:", parallel)
batch_size = 50 batch_size = 50
vocab, train_reader, test_reader = utils.prepare_data( vocab, train_reader, test_reader = utils.prepare_data(
train_file, test_file,batch_size=batch_size * get_cards(args),\ train_file, test_file,batch_size=batch_size * get_cards(args),\
...@@ -194,8 +198,8 @@ def train_net(): ...@@ -194,8 +198,8 @@ def train_net():
base_lr=0.01, base_lr=0.01,
batch_size=batch_size, batch_size=batch_size,
pass_num=10, pass_num=10,
use_cuda=True, use_cuda=use_cuda,
parallel=False, parallel=parallel,
model_dir="model_recall20", model_dir="model_recall20",
init_low_bound=-0.1, init_low_bound=-0.1,
init_high_bound=0.1) init_high_bound=0.1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册