提交 6b8b92b2 编写于 作者: J jshower

fix ISSUE=1534

上级 8df8d71e
......@@ -38,12 +38,10 @@ def infer(model_path, batch_size, test_data_file, vocab_file, target_file,
for data in test_data():
word = to_lodtensor([x[0] for x in data], place)
mark = to_lodtensor([x[1] for x in data], place)
target = to_lodtensor([x[2] for x in data], place)
crf_decode = exe.run(
inference_program,
feed={"word": word,
"mark": mark,
"target": target},
"mark": mark},
fetch_list=fetch_targets,
return_numpy=False)
lod_info = (crf_decode[0].lod())[0]
......@@ -66,7 +64,7 @@ def infer(model_path, batch_size, test_data_file, vocab_file, target_file,
if __name__ == "__main__":
infer(
model_path="models/params_pass_0",
model_path="models/params_pass_16",
batch_size=6,
test_data_file="data/test",
vocab_file="data/vocab.txt",
......
......@@ -61,22 +61,26 @@ def main(train_data_file,
avg_cost, feature_out, word, mark, target = ner_net(
word_dict_len, label_dict_len, parallel)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=1e-3)
sgd_optimizer.minimize(avg_cost)
crf_decode = fluid.layers.crf_decoding(
input=feature_out, param_attr=fluid.ParamAttr(name='crfw'))
(precision, recall, f1_score, num_infer_chunks, num_label_chunks,
num_correct_chunks) = fluid.layers.chunk_eval(
input=crf_decode,
label=target,
chunk_scheme="IOB",
num_chunk_types=int(math.ceil((label_dict_len - 1) / 2.0)))
input=crf_decode,
label=target,
chunk_scheme="IOB",
num_chunk_types=int(math.ceil((label_dict_len - 1) / 2.0)))
chunk_evaluator = fluid.metrics.ChunkEvaluator()
inference_program = fluid.default_main_program().clone(for_test=True)
test_fetch_list = [num_infer_chunks, num_label_chunks, num_correct_chunks]
sgd_optimizer = fluid.optimizer.SGD(learning_rate=1e-2)
sgd_optimizer.minimize(avg_cost)
#print(inference_program)
#exit(0)
#test_fetch_list = [num_infer_chunks, num_label_chunks, num_correct_chunks]
if "CE_MODE_X" not in os.environ:
train_reader = paddle.batch(
......@@ -111,6 +115,7 @@ def main(train_data_file,
for pass_id in six.moves.xrange(num_passes):
chunk_evaluator.reset()
for batch_id, data in enumerate(train_reader()):
#print(data)
cost_var, nums_infer, nums_label, nums_correct = exe.run(
fluid.default_main_program(),
feed=feeder.feed(data),
......@@ -135,7 +140,7 @@ def main(train_data_file,
" pass_f1_score:" + str(test_pass_f1_score))
save_dirname = os.path.join(model_save_dir, "params_pass_%d" % pass_id)
fluid.io.save_inference_model(save_dirname, ['word', 'mark', 'target'],
fluid.io.save_inference_model(save_dirname, ['word', 'mark'],
crf_decode, exe)
if "CE_MODE_X" in os.environ:
......@@ -146,8 +151,8 @@ def main(train_data_file,
if __name__ == "__main__":
main(
train_data_file="data/train",
test_data_file="data/test",
train_data_file="data/train.txt",
test_data_file="data/test.txt",
vocab_file="data/vocab.txt",
target_file="data/target.txt",
emb_file="data/wordVectors.txt",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册