提交 6d611bb8 编写于 作者: 1024的传说's avatar 1024的传说 提交者: Guo Sheng

Fix UTF-8 error in Python 2 during prediction (#3721)

上级 5a6f6822
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys

# Whitespace tokens shared by Python 2 and Python 3 callers.
#
# On Python 2 the tokens are defined as ``unicode`` literals; on Python 3 a
# plain ``str`` literal is already text, so no prefix is needed.
# NOTE(review): presumably this is so the tokens can be combined with text
# written through ``io.open(..., encoding=...)`` on Python 2 -- confirm
# against the callers.
#
# The interpreter major version is read from ``sys.version_info`` rather than
# by slicing the human-readable ``sys.version`` string, whose format is not
# meant to be parsed.
if sys.version_info[0] == 2:
    rt_tok = u'\n'
    tab_tok = u'\t'
    space_tok = u' '
else:
    rt_tok = '\n'
    tab_tok = '\t'
    space_tok = ' '
......@@ -24,16 +24,17 @@ import paddle.fluid as fluid
import dgu.reader as reader
from dgu_net import create_net
import dgu.define_paradigm as define_paradigm
import dgu.define_paradigm as define_paradigm
import dgu.define_predict_pack as define_predict_pack
from dgu.utils.configure import PDConfig
from dgu.utils.input_field import InputField
from dgu.utils.model_check import check_cuda
import dgu.utils.save_load_io as save_load_io
from dgu.utils.py23 import tab_tok, rt_tok
def do_predict(args):
def do_predict(args):
"""predict function"""
task_name = args.task_name.lower()
......@@ -63,34 +64,35 @@ def do_predict(args):
num_labels = len(processors[task_name].get_labels())
src_ids = fluid.data(
name='src_ids', shape=[-1, args.max_seq_len], dtype='int64')
name='src_ids', shape=[-1, args.max_seq_len], dtype='int64')
pos_ids = fluid.data(
name='pos_ids', shape=[-1, args.max_seq_len], dtype='int64')
name='pos_ids', shape=[-1, args.max_seq_len], dtype='int64')
sent_ids = fluid.data(
name='sent_ids', shape=[-1, args.max_seq_len], dtype='int64')
name='sent_ids', shape=[-1, args.max_seq_len], dtype='int64')
input_mask = fluid.data(
name='input_mask', shape=[-1, args.max_seq_len], dtype='float32')
if args.task_name == 'atis_slot':
name='input_mask',
shape=[-1, args.max_seq_len],
dtype='float32')
if args.task_name == 'atis_slot':
labels = fluid.data(
name='labels', shape=[-1, args.max_seq_len], dtype='int64')
name='labels', shape=[-1, args.max_seq_len], dtype='int64')
elif args.task_name in ['dstc2', 'dstc2_asr', 'multi-woz']:
labels = fluid.data(
name='labels', shape=[-1, num_labels], dtype='int64')
else:
labels = fluid.data(
name='labels', shape=[-1, 1], dtype='int64')
name='labels', shape=[-1, num_labels], dtype='int64')
else:
labels = fluid.data(name='labels', shape=[-1, 1], dtype='int64')
input_inst = [src_ids, pos_ids, sent_ids, input_mask, labels]
input_field = InputField(input_inst)
data_reader = fluid.io.PyReader(feed_list=input_inst,
capacity=4, iterable=False)
data_reader = fluid.io.PyReader(
feed_list=input_inst, capacity=4, iterable=False)
results = create_net(
is_training=False,
model_input=input_field,
num_labels=num_labels,
paradigm_inst=paradigm_inst,
args=args)
is_training=False,
model_input=input_field,
num_labels=num_labels,
paradigm_inst=paradigm_inst,
args=args)
probs = results.get("probs", None)
......@@ -117,7 +119,7 @@ def do_predict(args):
save_load_io.init_from_pretrain_model(args, exe, test_prog)
compiled_test_prog = fluid.CompiledProgram(test_prog)
processor = processors[task_name](data_dir=args.data_dir,
vocab_path=args.vocab_path,
max_seq_len=args.max_seq_len,
......@@ -126,34 +128,32 @@ def do_predict(args):
task_name=task_name,
random_seed=args.random_seed)
batch_generator = processor.data_generator(
batch_size=args.batch_size,
phase='test',
shuffle=False)
batch_size=args.batch_size, phase='test', shuffle=False)
data_reader.decorate_batch_generator(batch_generator)
data_reader.decorate_batch_generator(batch_generator)
data_reader.start()
all_results = []
while True:
try:
while True:
try:
results = exe.run(compiled_test_prog, fetch_list=fetch_list)
all_results.extend(results[0])
except fluid.core.EOFException:
except fluid.core.EOFException:
data_reader.reset()
break
np.set_printoptions(precision=4, suppress=True)
print("Write the predicted results into the output_prediction_file")
fw = io.open(args.output_prediction_file, 'w', encoding="utf8")
if task_name not in ['atis_slot']:
if task_name not in ['atis_slot']:
for index, result in enumerate(all_results):
tags = pred_func(result)
fw.write("%s\t%s\n" % (index, tags))
fw.write("%s%s%s%s" % (index, tab_tok, tags, rt_tok))
else:
tags = pred_func(all_results, args.max_seq_len)
for index, tag in enumerate(tags):
fw.write("%s\t%s\n" % (index, tag))
fw.write("%s%s%s%s" % (index, tab_tok, tag, rt_tok))
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册