提交 83424e13 编写于 作者: F frankwhzhang

fix README, change train.py input format

上级 ce19c955
...@@ -74,15 +74,18 @@ python convert_format.py ...@@ -74,15 +74,18 @@ python convert_format.py
``` ```
## 训练 ## 训练
GPU 环境 默认配置 --use_cuda 1 表示使用gpu --parallel 1 表示多卡 0 表示单卡
运行命令 `CUDA_VISIBLE_DEVICES=0 python train.py train_file test_file` 开始训练模型。
GPU 环境
运行命令 `CUDA_VISIBLE_DEVICES=0 python train.py train_file test_file ` 开始训练模型。
```python ```python
CUDA_VISIBLE_DEVICES=0 python train.py small_train.txt small_test.file CUDA_VISIBLE_DEVICES=0 python train.py small_train.txt small_test.txt --use_cuda 1 --parallel 0
``` ```
CPU 环境 CPU 环境
运行命令 `python train.py train_file test_file` 开始训练模型。 运行命令 `python train.py train_file test_file` 开始训练模型。
```python ```python
python train.py small_train.txt small_test.txt python train.py small_train.txt small_test.txt --use_cuda 0 --parallel 0
``` ```
当前支持的参数可参见[train.py](./train.py) `train_net` 函数 当前支持的参数可参见[train.py](./train.py) `train_net` 函数
......
...@@ -17,7 +17,8 @@ def parse_args(): ...@@ -17,7 +17,8 @@ def parse_args():
parser = argparse.ArgumentParser("gru4rec benchmark.") parser = argparse.ArgumentParser("gru4rec benchmark.")
parser.add_argument('train_file') parser.add_argument('train_file')
parser.add_argument('test_file') parser.add_argument('test_file')
parser.add_argument('--use_cuda', help='whether use gpu')
parser.add_argument('--parallel', help='whether parallel')
parser.add_argument( parser.add_argument(
'--enable_ce', '--enable_ce',
action='store_true', action='store_true',
...@@ -182,6 +183,8 @@ def train_net(): ...@@ -182,6 +183,8 @@ def train_net():
args = parse_args() args = parse_args()
train_file = args.train_file train_file = args.train_file
test_file = args.test_file test_file = args.test_file
use_cuda = True if args.use_cuda else False
parallel = True if args.parallel else False
batch_size = 50 batch_size = 50
vocab, train_reader, test_reader = utils.prepare_data( vocab, train_reader, test_reader = utils.prepare_data(
train_file, test_file,batch_size=batch_size * get_cards(args),\ train_file, test_file,batch_size=batch_size * get_cards(args),\
...@@ -194,8 +197,8 @@ def train_net(): ...@@ -194,8 +197,8 @@ def train_net():
base_lr=0.01, base_lr=0.01,
batch_size=batch_size, batch_size=batch_size,
pass_num=10, pass_num=10,
use_cuda=True, use_cuda=use_cuda,
parallel=False, parallel=parallel,
model_dir="model_recall20", model_dir="model_recall20",
init_low_bound=-0.1, init_low_bound=-0.1,
init_high_bound=0.1) init_high_bound=0.1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册