# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import numpy as np
import argparse
import collections
import paddle
import paddle.fluid as fluid

import dgu.reader as reader
from dgu_net import create_net
import dgu.define_paradigm as define_paradigm 
import dgu.define_predict_pack as define_predict_pack

from dgu.utils.configure import PDConfig
from dgu.utils.input_field import InputField
from dgu.utils.model_check import check_cuda
import dgu.utils.save_load_io as save_load_io

def do_predict(args):
    """Run inference for a DGU task and write the predictions to a file.

    Builds the inference network for ``args.task_name``, loads weights from
    ``args.init_from_params`` and/or ``args.init_from_pretrain_model``,
    feeds the 'test' phase batches of the task's data processor through the
    network, and writes one ``"<index>\\t<tags>"`` line per prediction to
    ``args.output_prediction_file``.

    Args:
        args: Run configuration (a built ``PDConfig`` when run as a script).
            Fields read directly here: task_name, random_seed, max_seq_len,
            use_cuda, init_from_params, init_from_pretrain_model, data_dir,
            vocab_path, do_lower_case, in_tokens, batch_size and
            output_prediction_file. ``create_net`` and the ``save_load_io``
            helpers also receive ``args`` and may read more fields.

    Raises:
        ValueError: If the task has no data processor, or if neither
            ``init_from_params`` nor ``init_from_pretrain_model`` is set.
        RuntimeError: If ``create_net`` returns no ``"probs"`` output.
    """
    # Task names are matched case-insensitively; every lookup and branch
    # below must use this normalized name, not the raw args.task_name.
    task_name = args.task_name.lower()

    processors = {
        'udc': reader.UDCProcessor,
        'swda': reader.SWDAProcessor,
        'mrda': reader.MRDAProcessor,
        'atis_slot': reader.ATISSlotProcessor,
        'atis_intent': reader.ATISIntentProcessor,
        'dstc2': reader.DSTC2Processor,
    }
    if task_name not in processors:
        raise ValueError("Unsupported task_name %r; expected one of: %s"
                         % (args.task_name, ", ".join(sorted(processors))))
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not (args.init_from_params or args.init_from_pretrain_model):
        raise ValueError("Prediction needs trained weights: set "
                         "init_from_params or init_from_pretrain_model")

    paradigm_inst = define_paradigm.Paradigm(task_name)
    pred_inst = define_predict_pack.DefinePredict()
    pred_func = getattr(pred_inst, pred_inst.task_map[task_name])

    test_prog = fluid.default_main_program()
    startup_prog = fluid.default_startup_program()

    with fluid.program_guard(test_prog, startup_prog):
        test_prog.random_seed = args.random_seed
        startup_prog.random_seed = args.random_seed

        with fluid.unique_name.guard():
            # define inputs of the network
            num_labels = len(processors[task_name].get_labels())

            src_ids = fluid.data(
                name='src_ids', shape=[-1, args.max_seq_len], dtype='int64')
            pos_ids = fluid.data(
                name='pos_ids', shape=[-1, args.max_seq_len], dtype='int64')
            sent_ids = fluid.data(
                name='sent_ids', shape=[-1, args.max_seq_len], dtype='int64')
            input_mask = fluid.data(
                name='input_mask', shape=[-1, args.max_seq_len],
                dtype='float32')

            if task_name == 'atis_slot':
                # one label column per sequence position
                label_shape = [-1, args.max_seq_len]
            elif task_name in ('dstc2', 'dstc2_asr', 'multi-woz'):
                # one column per label; 'dstc2_asr' and 'multi-woz' have no
                # entry in `processors`, so only 'dstc2' reaches here today
                label_shape = [-1, num_labels]
            else:
                # a single label value per example
                label_shape = [-1, 1]
            labels = fluid.data(name='labels', shape=label_shape, dtype='int64')

            input_inst = [src_ids, pos_ids, sent_ids, input_mask, labels]
            input_field = InputField(input_inst)
            data_reader = fluid.io.PyReader(feed_list=input_inst,
                                            capacity=4, iterable=False)

            results = create_net(
                is_training=False,
                model_input=input_field,
                num_labels=num_labels,
                paradigm_inst=paradigm_inst,
                args=args)

            probs = results.get("probs")
            if probs is None:
                raise RuntimeError(
                    "create_net returned no 'probs' output for task %r"
                    % task_name)
            # NOTE(review): presumably marked persistable so it can be
            # fetched from the cloned test program below -- confirm.
            probs.persistable = True

            fetch_list = [probs.name]

    # clone(for_test=True) sets the is_test attribute of the operators to True
    test_prog = test_prog.clone(for_test=True)

    if args.use_cuda:
        place = fluid.CUDAPlace(int(os.getenv('FLAGS_selected_gpus', '0')))
    else:
        place = fluid.CPUPlace()

    exe = fluid.Executor(place)
    exe.run(startup_prog)

    # Both weight sources may be configured; the pretrain model loads second.
    if args.init_from_params:
        save_load_io.init_from_params(args, exe, test_prog)
    if args.init_from_pretrain_model:
        save_load_io.init_from_pretrain_model(args, exe, test_prog)

    compiled_test_prog = fluid.CompiledProgram(test_prog)

    processor = processors[task_name](data_dir=args.data_dir,
                                      vocab_path=args.vocab_path,
                                      max_seq_len=args.max_seq_len,
                                      do_lower_case=args.do_lower_case,
                                      in_tokens=args.in_tokens,
                                      task_name=task_name,
                                      random_seed=args.random_seed)
    batch_generator = processor.data_generator(
        batch_size=args.batch_size,
        phase='test',
        shuffle=False)

    data_reader.decorate_batch_generator(batch_generator)
    data_reader.start()

    all_results = []
    while True:
        try:
            batch_out = exe.run(compiled_test_prog, fetch_list=fetch_list)
        except fluid.core.EOFException:
            # the batch generator is exhausted; reset the reader and stop
            data_reader.reset()
            break
        # batch_out[0] is the fetched value of `probs` for this batch
        all_results.extend(batch_out[0])

    # Process-wide numpy print format; presumably affects how pred_func
    # output is rendered when it contains arrays.
    np.set_printoptions(precision=4, suppress=True)
    print("Write the predicted results into the output_prediction_file")
    with open(args.output_prediction_file, 'w') as fw:
        if task_name == 'atis_slot':
            # the slot predictor consumes all rows at once plus max_seq_len
            tags = pred_func(all_results, args.max_seq_len)
            for index, tag in enumerate(tags):
                fw.write("%s\t%s\n" % (index, tag))
        else:
            for index, result in enumerate(all_results):
                tags = pred_func(result)
                fw.write("%s\t%s\n" % (index, tags))


if __name__ == "__main__":
    # Load the run configuration from the DGU YAML file, then print it.
    args = PDConfig(yaml_file="./data/config/dgu.yaml")
    args.build()
    args.Print()

    # Validate the use_cuda setting before building any program.
    check_cuda(args.use_cuda)

    do_predict(args)