提交 e3e2ca05 编写于 作者: G gmcather

fix bug: batch_size diff

上级 14139f8f
......@@ -89,7 +89,7 @@ def train(train_reader,
def train_net():
word_dict, train_reader, test_reader = utils.prepare_data(
"imdb", self_dict=False, batch_size=128, buf_size=50000)
"imdb", self_dict=False, batch_size=4, buf_size=50000)
if sys.argv[1] == "bow":
train(
......@@ -101,7 +101,7 @@ def train_net():
save_dirname="bow_model",
lr=0.002,
pass_num=30,
batch_size=128)
batch_size=4)
elif sys.argv[1] == "cnn":
train(
train_reader,
......@@ -134,7 +134,7 @@ def train_net():
save_dirname="gru_model",
lr=0.05,
pass_num=30,
batch_size=128)
batch_size=4)
else:
print("network name cannot be found!")
sys.exit(1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册