提交 ddf711a1 编写于 作者: J jshower

for code review

上级 28cf0218
......@@ -141,7 +141,8 @@ def ner_net(word_dict_len, label_dict_len):
label=target,
param_attr=fluid.ParamAttr(
name='crfw',
learning_rate=0.2, ))
learning_rate=0.2,
))
avg_cost = fluid.layers.mean(x=crf_cost)
return avg_cost, emission
......@@ -165,9 +166,11 @@ def test2(exe, chunk_evaluator, inference_program, test_data, place,
target = to_lodtensor(map(lambda x: x[2], data), place)
result_list = exe.run(
inference_program,
feed={"word": word,
"mention": mention,
"target": target},
feed={
"word": word,
"mention": mention,
"target": target
},
fetch_list=cur_fetch_list)
number_infer = np.array(result_list[0])
number_label = np.array(result_list[1])
......@@ -186,14 +189,16 @@ def test(test_exe, chunk_evaluator, inference_program, test_data, place,
target = to_lodtensor(map(lambda x: x[2], data), place)
result_list = test_exe.run(
fetch_list=cur_fetch_list,
feed={"word": word,
"mention": mention,
"target": target})
feed={
"word": word,
"mention": mention,
"target": target
})
number_infer = np.array(result_list[0])
number_label = np.array(result_list[1])
number_correct = np.array(result_list[2])
chunk_evaluator.update(number_infer.sum(),
number_label.sum(), number_correct.sum())
chunk_evaluator.update(number_infer.sum(), number_label.sum(),
number_correct.sum())
return chunk_evaluator.eval()
......@@ -208,12 +213,11 @@ def main(train_data_file, test_data_file, model_save_dir, num_passes):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
avg_cost, feature_out, word, mention, target = ner_net(word_dict_len,
label_dict_len)
avg_cost, feature_out, word, mention, target = ner_net(
word_dict_len, label_dict_len)
crf_decode = fluid.layers.crf_decoding(
input=feature_out, param_attr=fluid.ParamAttr(
name='crfw'))
input=feature_out, param_attr=fluid.ParamAttr(name='crfw'))
sgd_optimizer = fluid.optimizer.SGD(learning_rate=1e-3)
sgd_optimizer.minimize(avg_cost)
......@@ -277,20 +281,20 @@ def main(train_data_file, test_data_file, model_save_dir, num_passes):
except StopIteration:
break
end_time = time.time()
print("pass_id:" + str(pass_id) + ", time_cost:" + str(
end_time - start_time) + "s")
print("pass_id:" + str(pass_id) + ", time_cost:" +
str(end_time - start_time) + "s")
precision, recall, f1_score = chunk_evaluator.eval()
print("[Train] precision:" + str(precision) + ", recall:" + str(
recall) + ", f1:" + str(f1_score))
print("[Train] precision:" + str(precision) + ", recall:" +
str(recall) + ", f1:" + str(f1_score))
p, r, f1 = test2(
exe, chunk_evaluator, inference_program, test_reader, place,
[num_infer_chunks, num_label_chunks, num_correct_chunks])
print("[Test] precision:" + str(p) + ", recall:" + str(r) + ", f1:"
+ str(f1))
print("[Test] precision:" + str(p) + ", recall:" + str(r) +
", f1:" + str(f1))
save_dirname = os.path.join(model_save_dir,
"params_pass_%d" % pass_id)
fluid.io.save_inference_model(
save_dirname, ['word', 'mention'], [crf_decode], exe)
fluid.io.save_inference_model(save_dirname, ['word', 'mention'],
[crf_decode], exe)
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册