diff --git a/fluid/PaddleNLP/text_classification/nets.py b/fluid/PaddleNLP/text_classification/nets.py index 6ba637dd087afd8e45ad8d0752ac9850ec49e627..4a7caad99f89ae6db0a748634a7c9b0d6632f2ec 100644 --- a/fluid/PaddleNLP/text_classification/nets.py +++ b/fluid/PaddleNLP/text_classification/nets.py @@ -101,7 +101,7 @@ def gru_net(data, hid_dim=128, hid_dim2=96, class_dim=2, - emb_lr=400.0): + emb_lr=30.0): """ gru net """ diff --git a/fluid/PaddleNLP/text_classification/train.py b/fluid/PaddleNLP/text_classification/train.py index 174636f06ec5fe07180347745f910166140e9eed..a6978a15d2d58a91998b6941a438d804e3e0ee5e 100644 --- a/fluid/PaddleNLP/text_classification/train.py +++ b/fluid/PaddleNLP/text_classification/train.py @@ -22,7 +22,6 @@ def train(train_reader, parallel, save_dirname, lr=0.2, - batch_size=128, pass_num=30): """ train network @@ -100,8 +99,7 @@ def train_net(): parallel=False, save_dirname="bow_model", lr=0.002, - pass_num=30, - batch_size=4) + pass_num=30) elif sys.argv[1] == "cnn": train( train_reader, @@ -111,8 +109,7 @@ def train_net(): parallel=False, save_dirname="cnn_model", lr=0.01, - pass_num=30, - batch_size=4) + pass_num=30) elif sys.argv[1] == "lstm": train( train_reader, @@ -122,19 +119,17 @@ def train_net(): parallel=False, save_dirname="lstm_model", lr=0.05, - pass_num=30, - batch_size=4) + pass_num=30) elif sys.argv[1] == "gru": train( train_reader, word_dict, - lstm_net, + gru_net, use_cuda=True, parallel=False, save_dirname="gru_model", lr=0.05, - pass_num=30, - batch_size=4) + pass_num=30) else: print("network name cannot be found!") sys.exit(1)