提交 83424e13 编写于 作者: F frankwhzhang

fix README, change train.py input format

上级 ce19c955
......@@ -74,15 +74,18 @@ python convert_format.py
```
## 训练
GPU 环境 默认配置
运行命令 `CUDA_VISIBLE_DEVICES=0 python train.py train_file test_file` 开始训练模型。
--use_cuda 1 表示使用gpu --parallel 1 表示多卡 0 表示单卡
GPU 环境
运行命令 `CUDA_VISIBLE_DEVICES=0 python train.py train_file test_file ` 开始训练模型。
```python
CUDA_VISIBLE_DEVICES=0 python train.py small_train.txt small_test.file
CUDA_VISIBLE_DEVICES=0 python train.py small_train.txt small_test.txt --use_cuda 1 --parallel 0
```
CPU 环境
运行命令 `python train.py train_file test_file` 开始训练模型。
```python
python train.py small_train.txt small_test.txt
python train.py small_train.txt small_test.txt --use_cuda 0 --parallel 0
```
当前支持的参数可参见[train.py](./train.py) `train_net` 函数
......
......@@ -17,7 +17,8 @@ def parse_args():
parser = argparse.ArgumentParser("gru4rec benchmark.")
parser.add_argument('train_file')
parser.add_argument('test_file')
parser.add_argument('--use_cuda', help='whether use gpu')
parser.add_argument('--parallel', help='whether parallel')
parser.add_argument(
'--enable_ce',
action='store_true',
......@@ -182,6 +183,8 @@ def train_net():
args = parse_args()
train_file = args.train_file
test_file = args.test_file
use_cuda = True if args.use_cuda else False
parallel = True if args.parallel else False
batch_size = 50
vocab, train_reader, test_reader = utils.prepare_data(
train_file, test_file,batch_size=batch_size * get_cards(args),\
......@@ -194,8 +197,8 @@ def train_net():
base_lr=0.01,
batch_size=batch_size,
pass_num=10,
use_cuda=True,
parallel=False,
use_cuda=use_cuda,
parallel=parallel,
model_dir="model_recall20",
init_low_bound=-0.1,
init_high_bound=0.1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册