提交 3414ce9b 编写于 作者: J jshower

for code style

上级 ddf711a1
...@@ -141,8 +141,7 @@ def ner_net(word_dict_len, label_dict_len): ...@@ -141,8 +141,7 @@ def ner_net(word_dict_len, label_dict_len):
label=target, label=target,
param_attr=fluid.ParamAttr( param_attr=fluid.ParamAttr(
name='crfw', name='crfw',
learning_rate=0.2, learning_rate=0.2, ))
))
avg_cost = fluid.layers.mean(x=crf_cost) avg_cost = fluid.layers.mean(x=crf_cost)
return avg_cost, emission return avg_cost, emission
...@@ -166,11 +165,9 @@ def test2(exe, chunk_evaluator, inference_program, test_data, place, ...@@ -166,11 +165,9 @@ def test2(exe, chunk_evaluator, inference_program, test_data, place,
target = to_lodtensor(map(lambda x: x[2], data), place) target = to_lodtensor(map(lambda x: x[2], data), place)
result_list = exe.run( result_list = exe.run(
inference_program, inference_program,
feed={ feed={"word": word,
"word": word, "mention": mention,
"mention": mention, "target": target},
"target": target
},
fetch_list=cur_fetch_list) fetch_list=cur_fetch_list)
number_infer = np.array(result_list[0]) number_infer = np.array(result_list[0])
number_label = np.array(result_list[1]) number_label = np.array(result_list[1])
...@@ -189,16 +186,14 @@ def test(test_exe, chunk_evaluator, inference_program, test_data, place, ...@@ -189,16 +186,14 @@ def test(test_exe, chunk_evaluator, inference_program, test_data, place,
target = to_lodtensor(map(lambda x: x[2], data), place) target = to_lodtensor(map(lambda x: x[2], data), place)
result_list = test_exe.run( result_list = test_exe.run(
fetch_list=cur_fetch_list, fetch_list=cur_fetch_list,
feed={ feed={"word": word,
"word": word, "mention": mention,
"mention": mention, "target": target})
"target": target
})
number_infer = np.array(result_list[0]) number_infer = np.array(result_list[0])
number_label = np.array(result_list[1]) number_label = np.array(result_list[1])
number_correct = np.array(result_list[2]) number_correct = np.array(result_list[2])
chunk_evaluator.update(number_infer.sum(), number_label.sum(), chunk_evaluator.update(number_infer.sum(),
number_correct.sum()) number_label.sum(), number_correct.sum())
return chunk_evaluator.eval() return chunk_evaluator.eval()
...@@ -213,8 +208,8 @@ def main(train_data_file, test_data_file, model_save_dir, num_passes): ...@@ -213,8 +208,8 @@ def main(train_data_file, test_data_file, model_save_dir, num_passes):
main = fluid.Program() main = fluid.Program()
startup = fluid.Program() startup = fluid.Program()
with fluid.program_guard(main, startup): with fluid.program_guard(main, startup):
avg_cost, feature_out, word, mention, target = ner_net( avg_cost, feature_out, word, mention, target = ner_net(word_dict_len,
word_dict_len, label_dict_len) label_dict_len)
crf_decode = fluid.layers.crf_decoding( crf_decode = fluid.layers.crf_decoding(
input=feature_out, param_attr=fluid.ParamAttr(name='crfw')) input=feature_out, param_attr=fluid.ParamAttr(name='crfw'))
...@@ -281,16 +276,16 @@ def main(train_data_file, test_data_file, model_save_dir, num_passes): ...@@ -281,16 +276,16 @@ def main(train_data_file, test_data_file, model_save_dir, num_passes):
except StopIteration: except StopIteration:
break break
end_time = time.time() end_time = time.time()
print("pass_id:" + str(pass_id) + ", time_cost:" + print("pass_id:" + str(pass_id) + ", time_cost:" + str(
str(end_time - start_time) + "s") end_time - start_time) + "s")
precision, recall, f1_score = chunk_evaluator.eval() precision, recall, f1_score = chunk_evaluator.eval()
print("[Train] precision:" + str(precision) + ", recall:" + print("[Train] precision:" + str(precision) + ", recall:" + str(
str(recall) + ", f1:" + str(f1_score)) recall) + ", f1:" + str(f1_score))
p, r, f1 = test2( p, r, f1 = test2(
exe, chunk_evaluator, inference_program, test_reader, place, exe, chunk_evaluator, inference_program, test_reader, place,
[num_infer_chunks, num_label_chunks, num_correct_chunks]) [num_infer_chunks, num_label_chunks, num_correct_chunks])
print("[Test] precision:" + str(p) + ", recall:" + str(r) + print("[Test] precision:" + str(p) + ", recall:" + str(r) + ", f1:"
", f1:" + str(f1)) + str(f1))
save_dirname = os.path.join(model_save_dir, save_dirname = os.path.join(model_save_dir,
"params_pass_%d" % pass_id) "params_pass_%d" % pass_id)
fluid.io.save_inference_model(save_dirname, ['word', 'mention'], fluid.io.save_inference_model(save_dirname, ['word', 'mention'],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册