提交 650078de 编写于 作者: Z Zeyu Chen 提交者: Yibing Liu

fix bug of gru_net train, remove useless parameter, fix emb learning rate of...

fix bug of gru_net train, remove useless parameter, fix emb learning rate of gru default setting (#1879)
上级 2c65b659
...@@ -101,7 +101,7 @@ def gru_net(data, ...@@ -101,7 +101,7 @@ def gru_net(data,
hid_dim=128, hid_dim=128,
hid_dim2=96, hid_dim2=96,
class_dim=2, class_dim=2,
emb_lr=400.0): emb_lr=30.0):
""" """
gru net gru net
""" """
......
...@@ -22,7 +22,6 @@ def train(train_reader, ...@@ -22,7 +22,6 @@ def train(train_reader,
parallel, parallel,
save_dirname, save_dirname,
lr=0.2, lr=0.2,
batch_size=128,
pass_num=30): pass_num=30):
""" """
train network train network
...@@ -100,8 +99,7 @@ def train_net(): ...@@ -100,8 +99,7 @@ def train_net():
parallel=False, parallel=False,
save_dirname="bow_model", save_dirname="bow_model",
lr=0.002, lr=0.002,
pass_num=30, pass_num=30)
batch_size=4)
elif sys.argv[1] == "cnn": elif sys.argv[1] == "cnn":
train( train(
train_reader, train_reader,
...@@ -111,8 +109,7 @@ def train_net(): ...@@ -111,8 +109,7 @@ def train_net():
parallel=False, parallel=False,
save_dirname="cnn_model", save_dirname="cnn_model",
lr=0.01, lr=0.01,
pass_num=30, pass_num=30)
batch_size=4)
elif sys.argv[1] == "lstm": elif sys.argv[1] == "lstm":
train( train(
train_reader, train_reader,
...@@ -122,19 +119,17 @@ def train_net(): ...@@ -122,19 +119,17 @@ def train_net():
parallel=False, parallel=False,
save_dirname="lstm_model", save_dirname="lstm_model",
lr=0.05, lr=0.05,
pass_num=30, pass_num=30)
batch_size=4)
elif sys.argv[1] == "gru": elif sys.argv[1] == "gru":
train( train(
train_reader, train_reader,
word_dict, word_dict,
lstm_net, gru_net,
use_cuda=True, use_cuda=True,
parallel=False, parallel=False,
save_dirname="gru_model", save_dirname="gru_model",
lr=0.05, lr=0.05,
pass_num=30, pass_num=30)
batch_size=4)
else: else:
print("network name cannot be found!") print("network name cannot be found!")
sys.exit(1) sys.exit(1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册