未验证 提交 b8cbdd33 编写于 作者: F frankwhzhang 提交者: GitHub

Merge pull request #1384 from frankwhzhang/gru4rec

add gru4rec model
......@@ -74,14 +74,16 @@ python convert_format.py
```
## 训练
GPU 环境 默认配置
运行命令 `CUDA_VISIBLE_DEVICES=0 python train.py train_file test_file` 开始训练模型。
```python
CUDA_VISIBLE_DEVICES=0 python train.py small_train.txt small_test.file
'--use_cuda 1' 表示使用gpu, 缺省表示使用cpu '--parallel 1' 表示使用多卡,缺省表示使用单卡
GPU 环境
运行命令 `CUDA_VISIBLE_DEVICES=0 python train.py train_file test_file --use_cuda 1` 开始训练模型。
```
CUDA_VISIBLE_DEVICES=0 python train.py small_train.txt small_test.txt --use_cuda 1
```
CPU 环境
运行命令 `python train.py train_file test_file` 开始训练模型。
```python
```
python train.py small_train.txt small_test.txt
```
......@@ -100,8 +102,8 @@ python train.py small_train.txt small_test.txt
base_lr=0.01, # base learning rate
batch_size=batch_size,
pass_num=10, # the number of passed for training
use_cuda=True, # whether to use GPU card
parallel=False, # whether to be parallel
use_cuda=use_cuda, # whether to use GPU card
parallel=parallel, # whether to be parallel
model_dir="model_recall20", # directory to save model
init_low_bound=-0.1, # uniform parameter initialization lower bound
init_high_bound=0.1) # uniform parameter initialization upper bound
......@@ -198,9 +200,9 @@ model saved in model_recall20/epoch_1
```
## 预测
运行命令 `CUDA_VISIBLE_DEVICES=0 python infer.py model_dir start_epoch last_epoch(inclusive) train_file test_file` 开始预测其中,start_epoch指定开始预测的轮次,last_epoch指定结束的轮次,例如
运行命令 `CUDA_VISIBLE_DEVICES=0 python infer.py model_dir start_epoch last_epoch(inclusive) train_file test_file` 开始预测.其中,start_epoch指定开始预测的轮次,last_epoch指定结束的轮次,例如
```python
CUDA_VISIBLE_DEVICES=0 python infer.py model 1 10 small_train.txt small_test.txt# prediction from epoch 1 to epoch 10 small_train.txt small_test.txt
CUDA_VISIBLE_DEVICES=0 python infer.py model 1 10 small_train.txt small_test.txt
```
## 预测结果示例
......
......@@ -17,7 +17,8 @@ def parse_args():
parser = argparse.ArgumentParser("gru4rec benchmark.")
parser.add_argument('train_file')
parser.add_argument('test_file')
parser.add_argument('--use_cuda', help='whether use gpu')
parser.add_argument('--parallel', help='whether parallel')
parser.add_argument(
'--enable_ce',
action='store_true',
......@@ -182,6 +183,9 @@ def train_net():
args = parse_args()
train_file = args.train_file
test_file = args.test_file
use_cuda = True if args.use_cuda else False
parallel = True if args.parallel else False
print("use_cuda:", use_cuda, "parallel:", parallel)
batch_size = 50
vocab, train_reader, test_reader = utils.prepare_data(
train_file, test_file,batch_size=batch_size * get_cards(args),\
......@@ -194,8 +198,8 @@ def train_net():
base_lr=0.01,
batch_size=batch_size,
pass_num=10,
use_cuda=True,
parallel=False,
use_cuda=use_cuda,
parallel=parallel,
model_dir="model_recall20",
init_low_bound=-0.1,
init_high_bound=0.1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册