From ddf711a111075242b56f5a6a3fdc2cae949886ed Mon Sep 17 00:00:00 2001 From: jshower Date: Mon, 4 Jun 2018 19:02:20 +0800 Subject: [PATCH] for code review --- fluid/chinese_ner/train.py | 46 +++++++++++++++++++++----------------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/fluid/chinese_ner/train.py b/fluid/chinese_ner/train.py index 763557b2..61c589bf 100644 --- a/fluid/chinese_ner/train.py +++ b/fluid/chinese_ner/train.py @@ -141,7 +141,8 @@ def ner_net(word_dict_len, label_dict_len): label=target, param_attr=fluid.ParamAttr( name='crfw', - learning_rate=0.2, )) + learning_rate=0.2, + )) avg_cost = fluid.layers.mean(x=crf_cost) return avg_cost, emission @@ -165,9 +166,11 @@ def test2(exe, chunk_evaluator, inference_program, test_data, place, target = to_lodtensor(map(lambda x: x[2], data), place) result_list = exe.run( inference_program, - feed={"word": word, - "mention": mention, - "target": target}, + feed={ + "word": word, + "mention": mention, + "target": target + }, fetch_list=cur_fetch_list) number_infer = np.array(result_list[0]) number_label = np.array(result_list[1]) @@ -186,14 +189,16 @@ def test(test_exe, chunk_evaluator, inference_program, test_data, place, target = to_lodtensor(map(lambda x: x[2], data), place) result_list = test_exe.run( fetch_list=cur_fetch_list, - feed={"word": word, - "mention": mention, - "target": target}) + feed={ + "word": word, + "mention": mention, + "target": target + }) number_infer = np.array(result_list[0]) number_label = np.array(result_list[1]) number_correct = np.array(result_list[2]) - chunk_evaluator.update(number_infer.sum(), - number_label.sum(), number_correct.sum()) + chunk_evaluator.update(number_infer.sum(), number_label.sum(), + number_correct.sum()) return chunk_evaluator.eval() @@ -208,12 +213,11 @@ def main(train_data_file, test_data_file, model_save_dir, num_passes): main = fluid.Program() startup = fluid.Program() with fluid.program_guard(main, startup): - avg_cost, feature_out, word, mention, target = ner_net(word_dict_len, - label_dict_len) + avg_cost, feature_out, word, mention, target = ner_net( + word_dict_len, label_dict_len) crf_decode = fluid.layers.crf_decoding( - input=feature_out, param_attr=fluid.ParamAttr( - name='crfw')) + input=feature_out, param_attr=fluid.ParamAttr(name='crfw')) sgd_optimizer = fluid.optimizer.SGD(learning_rate=1e-3) sgd_optimizer.minimize(avg_cost) @@ -277,20 +281,20 @@ def main(train_data_file, test_data_file, model_save_dir, num_passes): except StopIteration: break end_time = time.time() - print("pass_id:" + str(pass_id) + ", time_cost:" + str( - end_time - start_time) + "s") + print("pass_id:" + str(pass_id) + ", time_cost:" + + str(end_time - start_time) + "s") precision, recall, f1_score = chunk_evaluator.eval() - print("[Train] precision:" + str(precision) + ", recall:" + str( - recall) + ", f1:" + str(f1_score)) + print("[Train] precision:" + str(precision) + ", recall:" + + str(recall) + ", f1:" + str(f1_score)) p, r, f1 = test2( exe, chunk_evaluator, inference_program, test_reader, place, [num_infer_chunks, num_label_chunks, num_correct_chunks]) - print("[Test] precision:" + str(p) + ", recall:" + str(r) + ", f1:" - + str(f1)) + print("[Test] precision:" + str(p) + ", recall:" + str(r) + + ", f1:" + str(f1)) save_dirname = os.path.join(model_save_dir, "params_pass_%d" % pass_id) - fluid.io.save_inference_model( - save_dirname, ['word', 'mention'], [crf_decode], exe) + fluid.io.save_inference_model(save_dirname, ['word', 'mention'], + [crf_decode], exe) if __name__ == "__main__": -- GitLab