提交 d2b7bac9 编写于 作者: G gmcather

add train

上级 796c965a
......@@ -85,9 +85,9 @@ def train_net():
batch_size = 128, buf_size = 50000)
if sys.argv[1] == "bow":
train(train_reader, word_dict, bow_net, use_cuda=False,
train(train_reader, word_dict, bow_net, use_cuda=True,
parallel=False, save_dirname="bow_model", lr=0.002,
pass_num=1, batch_size=128)
pass_num=30, batch_size=128)
elif sys.argv[1] == "cnn":
train(train_reader, word_dict, cnn_net, use_cuda=True,
parallel=False, save_dirname="cnn_model", lr=0.01,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册