Commit 8119706b authored by Yibing Liu

Decrease batch_size to run demo of sequence_tagging_for_ner_ce

Parent 43b3541f
@@ -2,7 +2,7 @@ import os
 import math
 import numpy as np
-import paddle.v2 as paddle
+import paddle
 import paddle.fluid as fluid
 import reader
@@ -24,12 +24,19 @@ def test(exe, chunk_evaluator, inference_program, test_data, place):
     return chunk_evaluator.eval(exe)
 
 
-def main(train_data_file, test_data_file, vocab_file, target_file, emb_file,
-         model_save_dir, num_passes, use_gpu, parallel):
+def main(train_data_file,
+         test_data_file,
+         vocab_file,
+         target_file,
+         emb_file,
+         model_save_dir,
+         num_passes,
+         use_gpu,
+         parallel,
+         batch_size=200):
     if not os.path.exists(model_save_dir):
         os.mkdir(model_save_dir)
 
-    BATCH_SIZE = 200
     word_dict = load_dict(vocab_file)
     label_dict = load_dict(target_file)
@@ -62,12 +69,12 @@ def main(train_data_file, test_data_file, vocab_file, target_file, emb_file,
         paddle.reader.shuffle(
             reader.data_reader(train_data_file, word_dict, label_dict),
             buf_size=20000),
-        batch_size=BATCH_SIZE)
+        batch_size=batch_size)
     test_reader = paddle.batch(
         paddle.reader.shuffle(
             reader.data_reader(test_data_file, word_dict, label_dict),
             buf_size=20000),
-        batch_size=BATCH_SIZE)
+        batch_size=batch_size)
 
     place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
     feeder = fluid.DataFeeder(feed_list=[word, mark, target], place=place)
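For context, both readers touched in the hunk above use the usual `paddle.batch(paddle.reader.shuffle(...))` composition from the pre-2.0 Paddle reader API. A minimal, self-contained sketch of that pattern (with a toy reader standing in for the repo's `reader.data_reader`) shows how `batch_size` controls how many samples reach the `DataFeeder` per step:

```python
import paddle  # assumes the pre-2.0 API with paddle.batch / paddle.reader.shuffle

def toy_reader():
    # Stand-in for reader.data_reader: yields (word_ids, mark_ids, target_ids).
    for i in range(10):
        yield [i], [0], [i % 2]

# shuffle buffers up to buf_size samples before emitting them in random order;
# batch then groups consecutive samples into lists of length batch_size.
batched = paddle.batch(
    paddle.reader.shuffle(toy_reader, buf_size=10), batch_size=4)

for minibatch in batched():
    print(len(minibatch))  # 4, 4, 2 -- the final short batch is still emitted
```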
@@ -79,34 +86,33 @@ def main(train_data_file, test_data_file, vocab_file, target_file, emb_file,
     embedding_param = fluid.global_scope().find_var(embedding_name).get_tensor()
     embedding_param.set(word_vector_values, place)
 
-    batch_id = 0
     for pass_id in xrange(num_passes):
         chunk_evaluator.reset(exe)
-        for data in train_reader():
+        for batch_id, data in enumerate(train_reader()):
             cost, batch_precision, batch_recall, batch_f1_score = exe.run(
                 fluid.default_main_program(),
                 feed=feeder.feed(data),
                 fetch_list=[avg_cost] + chunk_evaluator.metrics)
             if batch_id % 5 == 0:
+                print(cost)
                 print("Pass " + str(pass_id) + ", Batch " + str(
                     batch_id) + ", Cost " + str(cost[0]) + ", Precision " + str(
                     batch_precision[0]) + ", Recall " + str(batch_recall[0])
                       + ", F1_score" + str(batch_f1_score[0]))
-            batch_id = batch_id + 1
         pass_precision, pass_recall, pass_f1_score = chunk_evaluator.eval(exe)
         print("[TrainSet] pass_id:" + str(pass_id) + " pass_precision:" + str(
             pass_precision) + " pass_recall:" + str(pass_recall) +
               " pass_f1_score:" + str(pass_f1_score))
 
-        pass_precision, pass_recall, pass_f1_score = test(
+        test_pass_precision, test_pass_recall, test_pass_f1_score = test(
             exe, chunk_evaluator, inference_program, test_reader, place)
         print("[TestSet] pass_id:" + str(pass_id) + " pass_precision:" + str(
-            pass_precision) + " pass_recall:" + str(pass_recall) +
-              " pass_f1_score:" + str(pass_f1_score))
+            test_pass_precision) + " pass_recall:" + str(test_pass_recall) +
+              " pass_f1_score:" + str(test_pass_f1_score))
 
         save_dirname = os.path.join(model_save_dir, "params_pass_%d" % pass_id)
         fluid.io.save_inference_model(save_dirname, ['word', 'mark', 'target'],
-                                      [crf_decode], exe)
+                                      crf_decode, exe)
 
 
 if __name__ == "__main__":
@@ -117,6 +123,7 @@ if __name__ == "__main__":
         target_file="data/target.txt",
         emb_file="data/wordVectors.txt",
         model_save_dir="models",
-        num_passes=1000,
+        num_passes=100,
+        batch_size=1,
         use_gpu=False,
         parallel=False)
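Taken together, the commit turns the hard-coded `BATCH_SIZE = 200` into a `batch_size` keyword (default 200) and runs the CE demo with `batch_size=1` and fewer passes. A hedged usage sketch, assuming this script is importable as a module named `train` and that the train/test/vocab paths (elided above) follow the repo's `data/` layout, shows how a caller could restore the old batch size:

```python
from train import main  # hypothetical module name for the script in this diff

main(
    train_data_file="data/train",   # illustrative placeholder path
    test_data_file="data/test",     # illustrative placeholder path
    vocab_file="data/vocab.txt",    # illustrative placeholder path
    target_file="data/target.txt",
    emb_file="data/wordVectors.txt",
    model_save_dir="models",
    num_passes=100,
    use_gpu=False,
    parallel=False,
    batch_size=200)  # override the demo's batch_size=1 with the old default
```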