# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

0
0YuanZhang0 已提交
16
import io
Y
Yibing Liu 已提交
17 18 19
import os
import sys
import numpy as np
0
0YuanZhang0 已提交
20 21
import argparse
import collections
Y
Yibing Liu 已提交
22 23 24
import paddle
import paddle.fluid as fluid

0
0YuanZhang0 已提交
25 26
import dgu.reader as reader
from dgu_net import create_net
27
import dgu.define_paradigm as define_paradigm
0
0YuanZhang0 已提交
28
import dgu.define_predict_pack as define_predict_pack
Y
Yibing Liu 已提交
29

0
0YuanZhang0 已提交
30 31 32 33
from dgu.utils.configure import PDConfig
from dgu.utils.input_field import InputField
from dgu.utils.model_check import check_cuda
import dgu.utils.save_load_io as save_load_io
34
from dgu.utils.py23 import tab_tok, rt_tok
Y
Yibing Liu 已提交
35

P
pkpk 已提交
36

37
def do_predict(args):
    """Run inference for one task and write the predictions to a file.

    Builds the test network inside the default main/startup programs, loads
    parameters, feeds the processor's 'test' phase batches through the
    network and writes one record per prediction to
    ``args.output_prediction_file``.

    Args:
        args: configuration object. The attributes read here are
            ``task_name``, ``random_seed``, ``max_seq_len``, ``use_cuda``,
            ``init_from_params``, ``init_from_pretrain_model``, ``data_dir``,
            ``vocab_path``, ``do_lower_case``, ``in_tokens``, ``batch_size``
            and ``output_prediction_file``. It is also passed through to
            ``create_net`` and the ``save_load_io`` loaders.

    Raises:
        KeyError: if the lower-cased task name is missing from
            ``pred_inst.task_map`` or from the processor table below.
        AssertionError: if neither ``args.init_from_params`` nor
            ``args.init_from_pretrain_model`` is set.
    """

    # The lower-cased name is the single key used for every task lookup and
    # task-specific branch in this function.
    task_name = args.task_name.lower()
    paradigm_inst = define_paradigm.Paradigm(task_name)
    pred_inst = define_predict_pack.DefinePredict()
    pred_func = getattr(pred_inst, pred_inst.task_map[task_name])

    processors = {
        'udc': reader.UDCProcessor,
        'swda': reader.SWDAProcessor,
        'mrda': reader.MRDAProcessor,
        'atis_slot': reader.ATISSlotProcessor,
        'atis_intent': reader.ATISIntentProcessor,
        'dstc2': reader.DSTC2Processor,
    }

    test_prog = fluid.default_main_program()
    startup_prog = fluid.default_startup_program()

    with fluid.program_guard(test_prog, startup_prog):
        test_prog.random_seed = args.random_seed
        startup_prog.random_seed = args.random_seed

        with fluid.unique_name.guard():

            # define inputs of the network
            num_labels = len(processors[task_name].get_labels())

            src_ids = fluid.data(
                name='src_ids', shape=[-1, args.max_seq_len], dtype='int64')
            pos_ids = fluid.data(
                name='pos_ids', shape=[-1, args.max_seq_len], dtype='int64')
            sent_ids = fluid.data(
                name='sent_ids', shape=[-1, args.max_seq_len], dtype='int64')
            input_mask = fluid.data(
                name='input_mask',
                shape=[-1, args.max_seq_len, 1],
                dtype='float32')

            # The label shape depends on the task. Branch on the normalized
            # task_name so a mixed-case args.task_name picks the same shape
            # as the processor selected above.
            if task_name == 'atis_slot':
                labels = fluid.data(
                    name='labels', shape=[-1, args.max_seq_len], dtype='int64')
            elif task_name in ['dstc2', 'dstc2_asr', 'multi-woz']:
                labels = fluid.data(
                    name='labels', shape=[-1, num_labels], dtype='int64')
            else:
                labels = fluid.data(name='labels', shape=[-1, 1], dtype='int64')

            input_inst = [src_ids, pos_ids, sent_ids, input_mask, labels]
            input_field = InputField(input_inst)
            data_reader = fluid.io.PyReader(
                feed_list=input_inst, capacity=4, iterable=False)

            results = create_net(
                is_training=False,
                model_input=input_field,
                num_labels=num_labels,
                paradigm_inst=paradigm_inst,
                args=args)

            probs = results.get("probs", None)

            probs.persistable = True

            fetch_list = [probs.name]

    # Clone with for_test=True to get the inference version of the program.
    test_prog = test_prog.clone(for_test=True)

    if args.use_cuda:
        place = fluid.CUDAPlace(int(os.getenv('FLAGS_selected_gpus', '0')))
    else:
        place = fluid.CPUPlace()

    exe = fluid.Executor(place)
    exe.run(startup_prog)

    assert (args.init_from_params) or (args.init_from_pretrain_model), (
        "either init_from_params or init_from_pretrain_model must be set")

    if args.init_from_params:
        save_load_io.init_from_params(args, exe, test_prog)
    if args.init_from_pretrain_model:
        save_load_io.init_from_pretrain_model(args, exe, test_prog)

    compiled_test_prog = fluid.CompiledProgram(test_prog)

    processor = processors[task_name](data_dir=args.data_dir,
                                      vocab_path=args.vocab_path,
                                      max_seq_len=args.max_seq_len,
                                      do_lower_case=args.do_lower_case,
                                      in_tokens=args.in_tokens,
                                      task_name=task_name,
                                      random_seed=args.random_seed)
    batch_generator = processor.data_generator(
        batch_size=args.batch_size, phase='test', shuffle=False)

    data_reader.decorate_batch_generator(batch_generator)
    data_reader.start()

    # Run batches until the reader signals end of data, collecting the
    # per-example entries of the first (only) fetched variable.
    all_results = []
    while True:
        try:
            batch_results = exe.run(compiled_test_prog, fetch_list=fetch_list)
            all_results.extend(batch_results[0])
        except fluid.core.EOFException:
            data_reader.reset()
            break

    np.set_printoptions(precision=4, suppress=True)
    print("Write the predicted results into the output_prediction_file")

    # The context manager guarantees the file is flushed and closed, even if
    # pred_func or a write raises part-way through.
    with io.open(args.output_prediction_file, 'w', encoding="utf8") as fw:
        if task_name not in ['atis_slot']:
            # One pred_func call per collected result.
            for index, result in enumerate(all_results):
                tags = pred_func(result)
                fw.write("%s%s%s%s" % (index, tab_tok, tags, rt_tok))
        else:
            # atis_slot: pred_func takes the whole result list plus
            # max_seq_len and returns an iterable of tags.
            tags = pred_func(all_results, args.max_seq_len)
            for index, tag in enumerate(tags):
                fw.write("%s%s%s%s" % (index, tab_tok, tag, rt_tok))
P
pkpk 已提交
157

Y
Yibing Liu 已提交
158

0
0YuanZhang0 已提交
159
if __name__ == "__main__":
    # Script entry point: build the configuration from the YAML file and
    # print it, check the CUDA setting, then run prediction.
    config_path = "./data/config/dgu.yaml"
    args = PDConfig(yaml_file=config_path)
    args.build()
    args.Print()

    check_cuda(args.use_cuda)
    do_predict(args)