predict.py 5.8 KB
Newer Older
0
0YuanZhang0 已提交
1
# -*- coding: utf-8 -*-
0
0YuanZhang0 已提交
2
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
Y
Yibing Liu 已提交
3 4 5 6 7 8 9 10 11 12 13 14 15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

0
0YuanZhang0 已提交
16
import io
Y
Yibing Liu 已提交
17 18 19
import os
import sys
import numpy as np
0
0YuanZhang0 已提交
20 21
import argparse
import collections
Y
Yibing Liu 已提交
22 23 24
import paddle
import paddle.fluid as fluid

0
0YuanZhang0 已提交
25 26 27 28
import dgu.reader as reader
from dgu_net import create_net
import dgu.define_paradigm as define_paradigm 
import dgu.define_predict_pack as define_predict_pack
Y
Yibing Liu 已提交
29

0
0YuanZhang0 已提交
30 31 32 33
from dgu.utils.configure import PDConfig
from dgu.utils.input_field import InputField
from dgu.utils.model_check import check_cuda
import dgu.utils.save_load_io as save_load_io
Y
Yibing Liu 已提交
34

P
pkpk 已提交
35

0
0YuanZhang0 已提交
36 37
def _write_predictions(output_path, task_name, all_results, pred_func,
                       max_seq_len):
    """Decode raw network outputs with ``pred_func`` and write them to disk.

    Every output line has the form ``"<index>\\t<prediction>"``.  For all
    tasks except ``atis_slot`` the decoder is called once per result row.
    For ``atis_slot`` it is called once with all rows plus ``max_seq_len``,
    and each element it returns becomes one line.
    """
    # The context manager guarantees the file is flushed and closed even if
    # pred_func raises part-way through (the original never closed it).
    with io.open(output_path, 'w', encoding="utf8") as fw:
        if task_name != 'atis_slot':
            for index, result in enumerate(all_results):
                tags = pred_func(result)
                fw.write("%s\t%s\n" % (index, tags))
        else:
            tags = pred_func(all_results, max_seq_len)
            for index, tag in enumerate(tags):
                fw.write("%s\t%s\n" % (index, tag))


def do_predict(args):
    """Run inference for one DGU task and write the predictions to a file.

    The network for ``args.task_name`` is built in inference mode. Its
    parameters are loaded from ``args.init_from_params`` and/or
    ``args.init_from_pretrain_model``. The ``'test'`` phase of the task's
    processor is then fed through the network batch by batch, and the decoded
    predictions are written to ``args.output_prediction_file``.

    Args:
        args: configuration object (a ``PDConfig`` when run as a script).
            Fields read here: task_name, random_seed, max_seq_len, use_cuda,
            init_from_params, init_from_pretrain_model, data_dir, vocab_path,
            do_lower_case, in_tokens, batch_size, output_prediction_file.

    Raises:
        ValueError: if neither ``init_from_params`` nor
            ``init_from_pretrain_model`` is set.
        RuntimeError: if ``create_net`` returns no ``"probs"`` output.
    """
    # Validate up front. An ``assert`` is stripped under ``python -O``, and
    # prediction would then run with whatever values the startup program
    # initialised.
    if not (args.init_from_params or args.init_from_pretrain_model):
        raise ValueError(
            "do_predict requires init_from_params or "
            "init_from_pretrain_model to be set")

    # Every lookup below uses the lower-cased name, so e.g. "ATIS_SLOT" selects
    # the same processor, label shape and decoder as "atis_slot".
    task_name = args.task_name.lower()
    paradigm_inst = define_paradigm.Paradigm(task_name)
    pred_inst = define_predict_pack.DefinePredict()
    pred_func = getattr(pred_inst, pred_inst.task_map[task_name])

    processors = {
        'udc': reader.UDCProcessor,
        'swda': reader.SWDAProcessor,
        'mrda': reader.MRDAProcessor,
        'atis_slot': reader.ATISSlotProcessor,
        'atis_intent': reader.ATISIntentProcessor,
        'dstc2': reader.DSTC2Processor,
    }

    test_prog = fluid.default_main_program()
    startup_prog = fluid.default_startup_program()

    with fluid.program_guard(test_prog, startup_prog):
        test_prog.random_seed = args.random_seed
        startup_prog.random_seed = args.random_seed

        with fluid.unique_name.guard():

            # define inputs of the network
            num_labels = len(processors[task_name].get_labels())

            src_ids = fluid.data(
                name='src_ids', shape=[-1, args.max_seq_len], dtype='int64')
            pos_ids = fluid.data(
                name='pos_ids', shape=[-1, args.max_seq_len], dtype='int64')
            sent_ids = fluid.data(
                name='sent_ids', shape=[-1, args.max_seq_len], dtype='int64')
            input_mask = fluid.data(
                name='input_mask', shape=[-1, args.max_seq_len],
                dtype='float32')

            # The label layout depends on the task: one value per position for
            # atis_slot, a num_labels-wide vector for the dstc2-style tasks,
            # and a single id otherwise. Use the normalised task_name here
            # (the original compared the raw args.task_name).
            if task_name == 'atis_slot':
                labels = fluid.data(
                    name='labels', shape=[-1, args.max_seq_len],
                    dtype='int64')
            elif task_name in ['dstc2', 'dstc2_asr', 'multi-woz']:
                labels = fluid.data(
                    name='labels', shape=[-1, num_labels], dtype='int64')
            else:
                labels = fluid.data(
                    name='labels', shape=[-1, 1], dtype='int64')

            input_inst = [src_ids, pos_ids, sent_ids, input_mask, labels]
            input_field = InputField(input_inst)
            data_reader = fluid.io.PyReader(feed_list=input_inst,
                                            capacity=4, iterable=False)

            results = create_net(
                is_training=False,
                model_input=input_field,
                num_labels=num_labels,
                paradigm_inst=paradigm_inst,
                args=args)

            probs = results.get("probs", None)
            if probs is None:
                raise RuntimeError(
                    "create_net did not return a 'probs' output for task "
                    "'%s'" % task_name)

            probs.persistable = True

            fetch_list = [probs.name]

    # for_test=True sets the is_test attribute of the operators to True.
    test_prog = test_prog.clone(for_test=True)

    if args.use_cuda:
        place = fluid.CUDAPlace(int(os.getenv('FLAGS_selected_gpus', '0')))
    else:
        place = fluid.CPUPlace()

    exe = fluid.Executor(place)
    exe.run(startup_prog)

    if args.init_from_params:
        save_load_io.init_from_params(args, exe, test_prog)
    if args.init_from_pretrain_model:
        save_load_io.init_from_pretrain_model(args, exe, test_prog)

    compiled_test_prog = fluid.CompiledProgram(test_prog)

    processor = processors[task_name](data_dir=args.data_dir,
                                      vocab_path=args.vocab_path,
                                      max_seq_len=args.max_seq_len,
                                      do_lower_case=args.do_lower_case,
                                      in_tokens=args.in_tokens,
                                      task_name=task_name,
                                      random_seed=args.random_seed)
    batch_generator = processor.data_generator(
        batch_size=args.batch_size,
        phase='test',
        shuffle=False)

    data_reader.decorate_batch_generator(batch_generator)
    data_reader.start()

    # The non-iterable PyReader signals the end of the data by raising
    # EOFException from exe.run; the reader is then reset and the loop ends.
    all_results = []
    while True:
        try:
            batch_outputs = exe.run(compiled_test_prog, fetch_list=fetch_list)
        except fluid.core.EOFException:
            data_reader.reset()
            break
        all_results.extend(batch_outputs[0])

    np.set_printoptions(precision=4, suppress=True)
    print("Write the predicted results into the output_prediction_file")

    _write_predictions(args.output_prediction_file, task_name, all_results,
                       pred_func, args.max_seq_len)
P
pkpk 已提交
157

Y
Yibing Liu 已提交
158

0
0YuanZhang0 已提交
159
if __name__ == "__main__":
    # Load the DGU YAML configuration (plus any overrides handled by
    # PDConfig.build) and echo it.
    config = PDConfig(yaml_file="./data/config/dgu.yaml")
    config.build()
    config.Print()

    # Verify the requested device is usable before doing any work,
    # then run prediction with the resolved configuration.
    check_cuda(config.use_cuda)
    do_predict(config)