diff --git a/fluid/text_classification/train.py b/fluid/text_classification/train.py index e2514090fae3fac8546f6e8e08cd2ec0c52c4b32..0e295882134321f4609ae77a5eb230376c3dc2e7 100644 --- a/fluid/text_classification/train.py +++ b/fluid/text_classification/train.py @@ -85,9 +85,9 @@ def train_net(): batch_size = 128, buf_size = 50000) if sys.argv[1] == "bow": - train(train_reader, word_dict, bow_net, use_cuda=False, + train(train_reader, word_dict, bow_net, use_cuda=True, parallel=False, save_dirname="bow_model", lr=0.002, - pass_num=1, batch_size=128) + pass_num=30, batch_size=128) elif sys.argv[1] == "cnn": train(train_reader, word_dict, cnn_net, use_cuda=True, parallel=False, save_dirname="cnn_model", lr=0.01,