未验证 提交 271d9585 编写于 作者: J jiaozhenyu 提交者: GitHub

Merge pull request #961 from jshower/develop

fix inference bug
...@@ -211,13 +211,12 @@ def main(train_data_file, test_data_file, model_save_dir, num_passes): ...@@ -211,13 +211,12 @@ def main(train_data_file, test_data_file, model_save_dir, num_passes):
avg_cost, feature_out, word, mention, target = ner_net(word_dict_len, avg_cost, feature_out, word, mention, target = ner_net(word_dict_len,
label_dict_len) label_dict_len)
crf_decode = fluid.layers.crf_decoding(
input=feature_out, param_attr=fluid.ParamAttr(name='crfw'))
sgd_optimizer = fluid.optimizer.SGD(learning_rate=1e-3) sgd_optimizer = fluid.optimizer.SGD(learning_rate=1e-3)
sgd_optimizer.minimize(avg_cost) sgd_optimizer.minimize(avg_cost)
crf_decode = fluid.layers.crf_decoding(
input=feature_out, param_attr=fluid.ParamAttr(
name='crfw', ))
(precision, recall, f1_score, num_infer_chunks, num_label_chunks, (precision, recall, f1_score, num_infer_chunks, num_label_chunks,
num_correct_chunks) = fluid.layers.chunk_eval( num_correct_chunks) = fluid.layers.chunk_eval(
input=crf_decode, input=crf_decode,
...@@ -289,8 +288,8 @@ def main(train_data_file, test_data_file, model_save_dir, num_passes): ...@@ -289,8 +288,8 @@ def main(train_data_file, test_data_file, model_save_dir, num_passes):
+ str(f1)) + str(f1))
save_dirname = os.path.join(model_save_dir, save_dirname = os.path.join(model_save_dir,
"params_pass_%d" % pass_id) "params_pass_%d" % pass_id)
fluid.io.save_inference_model( fluid.io.save_inference_model(save_dirname, ['word', 'mention'],
save_dirname, ['word', 'mention', 'target'], [crf_decode], exe) [crf_decode], exe)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册