diff --git a/fluid/text_classification/train.py b/fluid/text_classification/train.py index b2ffe4c6723120103c9b3e310b070f4c773aeeb4..74264ce5d720a345aa1bf793510b75cf2ca74346 100644 --- a/fluid/text_classification/train.py +++ b/fluid/text_classification/train.py @@ -89,7 +89,7 @@ def train(train_reader, def train_net(): word_dict, train_reader, test_reader = utils.prepare_data( - "imdb", self_dict=False, batch_size=128, buf_size=50000) + "imdb", self_dict=False, batch_size=4, buf_size=50000) if sys.argv[1] == "bow": train( @@ -101,7 +101,7 @@ def train_net(): save_dirname="bow_model", lr=0.002, pass_num=30, - batch_size=128) + batch_size=4) elif sys.argv[1] == "cnn": train( train_reader, @@ -134,7 +134,7 @@ def train_net(): save_dirname="gru_model", lr=0.05, pass_num=30, - batch_size=128) + batch_size=4) else: print("network name cannot be found!") sys.exit(1)