predict.py 5.6 KB
Newer Older
0
0YuanZhang0 已提交
1
# -*- coding: utf-8 -*-
0
0YuanZhang0 已提交
2
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
Y
Yibing Liu 已提交
3 4 5 6 7 8 9 10 11 12 13 14 15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

0
0YuanZhang0 已提交
16
import io
Y
Yibing Liu 已提交
17 18 19
import os
import sys
import numpy as np
0
0YuanZhang0 已提交
20 21
import argparse
import collections
Y
Yibing Liu 已提交
22 23
import paddle.fluid as fluid

0
0YuanZhang0 已提交
24 25
import dgu.reader as reader
from dgu_net import create_net
26
import dgu.define_paradigm as define_paradigm
0
0YuanZhang0 已提交
27
import dgu.define_predict_pack as define_predict_pack
Y
Yibing Liu 已提交
28

0
0YuanZhang0 已提交
29 30 31
from dgu.utils.configure import PDConfig
from dgu.utils.input_field import InputField
from dgu.utils.model_check import check_cuda
32
from dgu.utils.py23 import tab_tok, rt_tok
Y
Yibing Liu 已提交
33

P
pkpk 已提交
34

35
def _label_shape(task_name, max_seq_len, num_labels):
    """Return the shape of the ``labels`` input placeholder for ``task_name``.

    ``atis_slot`` uses one label per token position (``max_seq_len`` wide),
    the dstc2-family / multi-woz tasks use a ``num_labels``-wide label
    vector, and every other task uses a single label column.
    """
    if task_name == 'atis_slot':
        return [-1, max_seq_len]
    # NOTE(review): 'dstc2_asr' and 'multi-woz' have no entry in the
    # processors table of do_predict, so today only 'dstc2' reaches this
    # branch; the names are kept so adding those processors just works.
    if task_name in ('dstc2', 'dstc2_asr', 'multi-woz'):
        return [-1, num_labels]
    return [-1, 1]


def _write_predictions(task_name, pred_func, all_results, max_seq_len,
                       output_path):
    """Write one line per prediction to ``output_path`` (UTF-8).

    Each line is ``index``, ``tab_tok``, the predicted tags, ``rt_tok``
    (presumably a tab and a newline -- see dgu.utils.py23).

    For ``atis_slot`` the task's ``pred_func`` is called once on all results
    together with ``max_seq_len``; for every other task it is called once
    per result row.
    """
    if task_name == 'atis_slot':
        tags_per_example = pred_func(all_results, max_seq_len)
    else:
        tags_per_example = (pred_func(result) for result in all_results)

    # ``with`` guarantees the file is flushed and closed even if pred_func
    # raises part-way through.
    with io.open(output_path, 'w', encoding="utf8") as fw:
        for index, tags in enumerate(tags_per_example):
            fw.write("%s%s%s%s" % (index, tab_tok, tags, rt_tok))


def do_predict(args):
    """Run DGU inference for ``args.task_name`` and write the predictions.

    Builds the inference network, restores parameters from
    ``args.init_from_params``, runs the processor's ``'test'`` split through
    the network, and writes the post-processed predictions to
    ``args.output_prediction_file``.

    Args:
        args: the configuration object built in ``__main__``. The fields read
            here are task_name, random_seed, max_seq_len, use_cuda,
            init_from_params, data_dir, vocab_path, do_lower_case, in_tokens,
            batch_size and output_prediction_file.

    Raises:
        ValueError: if the task is not supported, if
            ``args.init_from_params`` is empty, or if the network does not
            expose a ``"probs"`` output.
    """
    task_name = args.task_name.lower()

    processors = {
        'udc': reader.UDCProcessor,
        'swda': reader.SWDAProcessor,
        'mrda': reader.MRDAProcessor,
        'atis_slot': reader.ATISSlotProcessor,
        'atis_intent': reader.ATISIntentProcessor,
        'dstc2': reader.DSTC2Processor,
    }
    if task_name not in processors:
        raise ValueError(
            "Unsupported task_name %r; expected one of: %s"
            % (args.task_name, ", ".join(sorted(processors))))

    # A checkpoint is mandatory for prediction. Use an explicit raise rather
    # than ``assert`` so the check survives ``python -O``, and fail before
    # any graph is built.
    if not args.init_from_params:
        raise ValueError("init_from_params must be set for prediction")

    paradigm_inst = define_paradigm.Paradigm(task_name)
    pred_inst = define_predict_pack.DefinePredict()
    pred_func = getattr(pred_inst, pred_inst.task_map[task_name])

    test_prog = fluid.default_main_program()
    startup_prog = fluid.default_startup_program()

    with fluid.program_guard(test_prog, startup_prog):
        test_prog.random_seed = args.random_seed
        startup_prog.random_seed = args.random_seed

        with fluid.unique_name.guard():
            # define inputs of the network
            num_labels = len(processors[task_name].get_labels())

            src_ids = fluid.data(
                name='src_ids', shape=[-1, args.max_seq_len], dtype='int64')
            pos_ids = fluid.data(
                name='pos_ids', shape=[-1, args.max_seq_len], dtype='int64')
            sent_ids = fluid.data(
                name='sent_ids', shape=[-1, args.max_seq_len], dtype='int64')
            input_mask = fluid.data(
                name='input_mask',
                shape=[-1, args.max_seq_len, 1],
                dtype='float32')
            # Use the normalized (lowercased) task name, consistent with the
            # processor and predictor lookups above.
            labels = fluid.data(
                name='labels',
                shape=_label_shape(task_name, args.max_seq_len, num_labels),
                dtype='int64')

            input_inst = [src_ids, pos_ids, sent_ids, input_mask, labels]
            input_field = InputField(input_inst)
            data_reader = fluid.io.DataLoader.from_generator(
                feed_list=input_inst, capacity=4, iterable=False)

            results = create_net(
                is_training=False,
                model_input=input_field,
                num_labels=num_labels,
                paradigm_inst=paradigm_inst,
                args=args)

            probs = results.get("probs", None)
            if probs is None:
                raise ValueError(
                    "create_net returned no 'probs' output for task %r"
                    % task_name)
            fetch_list = [probs.name]

    # for_test=True sets the is_test attribute of the cloned operators to
    # True (inference behaviour).
    test_prog = test_prog.clone(for_test=True)

    if args.use_cuda:
        place = fluid.CUDAPlace(int(os.getenv('FLAGS_selected_gpus', '0')))
    else:
        place = fluid.CPUPlace()

    exe = fluid.Executor(place)
    exe.run(startup_prog)

    # Overwrite the freshly initialized parameters with the checkpoint.
    fluid.load(test_prog, args.init_from_params)

    compiled_test_prog = fluid.CompiledProgram(test_prog)

    processor = processors[task_name](data_dir=args.data_dir,
                                      vocab_path=args.vocab_path,
                                      max_seq_len=args.max_seq_len,
                                      do_lower_case=args.do_lower_case,
                                      in_tokens=args.in_tokens,
                                      task_name=task_name,
                                      random_seed=args.random_seed)
    batch_generator = processor.data_generator(
        batch_size=args.batch_size, phase='test', shuffle=False)

    data_reader.set_batch_generator(batch_generator, places=place)
    data_reader.start()

    all_results = []
    while True:
        try:
            batch_outputs = exe.run(compiled_test_prog, fetch_list=fetch_list)
        except fluid.core.EOFException:
            # The non-iterable loader signals the end of the data by raising
            # EOFException; reset it and stop.
            data_reader.reset()
            break
        # batch_outputs[0] is the fetched "probs" tensor for this batch;
        # extend() appends one entry per row.
        all_results.extend(batch_outputs[0])

    np.set_printoptions(precision=4, suppress=True)
    print("Write the predicted results into the output_prediction_file")

    _write_predictions(task_name, pred_func, all_results, args.max_seq_len,
                       args.output_prediction_file)


if __name__ == "__main__":
    # Load the run configuration from the DGU YAML file, build it, and echo
    # the resolved values before starting.
    args = PDConfig(yaml_file="./data/config/dgu.yaml")
    args.build()
    args.Print()

    # Validate the use_cuda setting against the installed Paddle build
    # before any program is constructed.
    check_cuda(args.use_cuda)

    do_predict(args)