未验证 提交 9983e3a9 编写于 作者: W wuxing_iie 提交者: GitHub

fix_lac_dygraph_to_1.7 (#4258)

Co-authored-by: NXing Wu <wuxing03@baidu.com>
上级 f69a0b5e
...@@ -49,7 +49,7 @@ def do_eval(args): ...@@ -49,7 +49,7 @@ def do_eval(args):
load_path = args.init_checkpoint load_path = args.init_checkpoint
state_dict, _ = fluid.dygraph.load_dygraph(load_path) state_dict, _ = fluid.dygraph.load_dygraph(load_path)
#import ipdb; ipdb.set_trace() #import ipdb; ipdb.set_trace()
state_dict["crf_decoding_0.crfw"]=state_dict["linear_chain_crf_0.crfw"] state_dict["linear_chain_crf.weight"]=state_dict["crf_decoding.weight"]
model.set_dict(state_dict) model.set_dict(state_dict)
model.eval() model.eval()
chunk_eval = Chunk_eval(int(math.ceil((dataset.num_labels - 1) / 2.0)), "IOB") chunk_eval = Chunk_eval(int(math.ceil((dataset.num_labels - 1) / 2.0)), "IOB")
......
...@@ -50,7 +50,7 @@ def do_infer(args): ...@@ -50,7 +50,7 @@ def do_infer(args):
load_path = args.init_checkpoint load_path = args.init_checkpoint
state_dict, _ = fluid.dygraph.load_dygraph(load_path) state_dict, _ = fluid.dygraph.load_dygraph(load_path)
#import ipdb; ipdb.set_trace() #import ipdb; ipdb.set_trace()
state_dict["crf_decoding_0.crfw"]=state_dict["linear_chain_crf_0.crfw"] state_dict["linear_chain_crf.weight"]=state_dict["crf_decoding.weight"]
model.set_dict(state_dict) model.set_dict(state_dict)
model.eval() model.eval()
chunk_eval = Chunk_eval(int(math.ceil((dataset.num_labels - 1) / 2.0)), "IOB") chunk_eval = Chunk_eval(int(math.ceil((dataset.num_labels - 1) / 2.0)), "IOB")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册