# predict.py (last committed by Xing Wu)
# -*- coding: UTF-8 -*-
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import time
import sys

import paddle.fluid as fluid
import paddle
import utils
import reader
import math
from sequence_labeling import lex_net, Chunk_eval
# Command-line parser for this script. ``__doc__`` is passed as the
# *description*: the first positional parameter of ArgumentParser is ``prog``
# (the program name), so passing it positionally would misuse the docstring.
parser = argparse.ArgumentParser(description=__doc__)
# 1. model parameters -- presumably registered on ``parser`` by
# utils.load_yaml from conf/args.yaml (relative path; confirm how load_yaml
# resolves it).
utils.load_yaml(parser, 'conf/args.yaml')
# NOTE(review): parsing here runs at import time, so importing this module
# consumes sys.argv. Kept because ``args`` is a module-level name other code
# may rely on; the ``__main__`` guard parses again before running.
args = parser.parse_args()

def _select_place(args):
    """Return the Paddle place (device) to run inference on.

    Uses the CPU unless ``args.use_cuda`` is set. On GPU, the device id is
    taken from the dygraph data-parallel environment when
    ``args.use_data_parallel`` is set; otherwise GPU 0 is used.
    """
    if not args.use_cuda:
        return fluid.CPUPlace()
    if args.use_data_parallel:
        return fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)
    return fluid.CUDAPlace(0)


def do_infer(args):
    """Run LAC inference over ``args.infer_data`` and print the decoded tags.

    Builds the dataset, restores ``lex_net`` weights from
    ``args.init_checkpoint`` in dygraph mode, and decodes every batch from the
    inference data loader. All results are collected first. Then one line is
    printed per ``(sent, tags)`` result, made of ``(ch, tag)`` pairs
    concatenated with no separator.

    Args:
        args: parsed configuration. Fields read directly here are
            ``use_cuda``, ``use_data_parallel``, ``infer_data`` and
            ``init_checkpoint``; the whole object is also forwarded to
            ``reader.Dataset``, ``reader.create_dataloader`` and ``lex_net``.

    Raises:
        KeyError: if the loaded state dict has no ``linear_chain_crf_0.crfw``
            entry.
    """
    dataset = reader.Dataset(args)
    place = _select_place(args)

    with fluid.dygraph.guard(place):
        infer_loader = reader.create_dataloader(
            args,
            file_name=args.infer_data,
            place=place,
            model='lac',
            reader=dataset,
            mode='infer')

        model = lex_net(args, dataset.vocab_size, dataset.num_labels)
        state_dict, _ = fluid.dygraph.load_dygraph(args.init_checkpoint)
        # Copy the CRF transition weights saved as 'linear_chain_crf_0.crfw'
        # to 'crf_decoding_0.crfw' -- presumably the parameter name of the
        # decoding layer used by lex_net at inference time (confirm against
        # lex_net if the checkpoint format changes).
        state_dict["crf_decoding_0.crfw"] = state_dict["linear_chain_crf_0.crfw"]
        model.set_dict(state_dict)
        model.eval()

        # Each loader batch unpacks to (words, length); the padded outputs
        # are trimmed back into per-item results by parse_padding_result.
        results = []
        for words, length in infer_loader():
            crf_decode = model(words, length=length)
            results += utils.parse_padding_result(
                words.numpy(), crf_decode.numpy(), length.numpy(), dataset)

        for sent, tags in results:
            result_list = ['(%s, %s)' % (ch, tag) for ch, tag in zip(sent, tags)]
            print(''.join(result_list))

if __name__ == '__main__':
    # Script entry point: parse sys.argv again and run inference with it.
    args = parser.parse_args()
    do_infer(args=args)