diff --git a/dygraph/lac/eval.py b/dygraph/lac/eval.py index f3bbd82cf51ba195e19fdd16806f0474299b48f7..03b41effd6eb20081564be689e593aca3baa3c19 100755 --- a/dygraph/lac/eval.py +++ b/dygraph/lac/eval.py @@ -49,7 +49,7 @@ def do_eval(args): load_path = args.init_checkpoint state_dict, _ = fluid.dygraph.load_dygraph(load_path) #import ipdb; ipdb.set_trace() - state_dict["crf_decoding_0.crfw"]=state_dict["linear_chain_crf_0.crfw"] + state_dict["linear_chain_crf.weight"]=state_dict["crf_decoding.weight"] model.set_dict(state_dict) model.eval() chunk_eval = Chunk_eval(int(math.ceil((dataset.num_labels - 1) / 2.0)), "IOB") diff --git a/dygraph/lac/predict.py b/dygraph/lac/predict.py index ab22b70db1410e73cc196f99eed71c6cac1b833e..6431f76b79a7442effc3180b103d3994399c9259 100755 --- a/dygraph/lac/predict.py +++ b/dygraph/lac/predict.py @@ -50,7 +50,7 @@ def do_infer(args): load_path = args.init_checkpoint state_dict, _ = fluid.dygraph.load_dygraph(load_path) #import ipdb; ipdb.set_trace() - state_dict["crf_decoding_0.crfw"]=state_dict["linear_chain_crf_0.crfw"] + state_dict["linear_chain_crf.weight"]=state_dict["crf_decoding.weight"] model.set_dict(state_dict) model.eval() chunk_eval = Chunk_eval(int(math.ceil((dataset.num_labels - 1) / 2.0)), "IOB")