From 6d611bb8b1aec78547c76eb2cde670c65180e33d Mon Sep 17 00:00:00 2001 From: Xing Wu <1160386409@qq.com> Date: Wed, 23 Oct 2019 13:31:05 +0800 Subject: [PATCH] fix utf-8 error in python2 when predict (#3721) --- .../dgu/utils/py23.py | 25 +++++++ .../dialogue_general_understanding/predict.py | 68 +++++++++---------- 2 files changed, 59 insertions(+), 34 deletions(-) create mode 100644 PaddleNLP/PaddleDialogue/dialogue_general_understanding/dgu/utils/py23.py diff --git a/PaddleNLP/PaddleDialogue/dialogue_general_understanding/dgu/utils/py23.py b/PaddleNLP/PaddleDialogue/dialogue_general_understanding/dgu/utils/py23.py new file mode 100644 index 00000000..0d84ddfa --- /dev/null +++ b/PaddleNLP/PaddleDialogue/dialogue_general_understanding/dgu/utils/py23.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +if sys.version[0] == '2': + rt_tok = u'\n' + tab_tok = u'\t' + space_tok = u' ' +else: + rt_tok = '\n' + tab_tok = '\t' + space_tok = ' ' diff --git a/PaddleNLP/PaddleDialogue/dialogue_general_understanding/predict.py b/PaddleNLP/PaddleDialogue/dialogue_general_understanding/predict.py index bab34ed2..8cc64f1b 100644 --- a/PaddleNLP/PaddleDialogue/dialogue_general_understanding/predict.py +++ b/PaddleNLP/PaddleDialogue/dialogue_general_understanding/predict.py @@ -24,16 +24,17 @@ import paddle.fluid as fluid import dgu.reader as reader from dgu_net import create_net -import dgu.define_paradigm as define_paradigm +import dgu.define_paradigm as define_paradigm import dgu.define_predict_pack as define_predict_pack from dgu.utils.configure import PDConfig from dgu.utils.input_field import InputField from dgu.utils.model_check import check_cuda import dgu.utils.save_load_io as save_load_io +from dgu.utils.py23 import tab_tok, rt_tok -def do_predict(args): +def do_predict(args): """predict function""" task_name = args.task_name.lower() @@ -63,34 +64,35 @@ def do_predict(args): num_labels = len(processors[task_name].get_labels()) src_ids = fluid.data( - name='src_ids', shape=[-1, args.max_seq_len], dtype='int64') + name='src_ids', shape=[-1, args.max_seq_len], dtype='int64') pos_ids = fluid.data( - name='pos_ids', shape=[-1, args.max_seq_len], dtype='int64') + name='pos_ids', shape=[-1, args.max_seq_len], dtype='int64') sent_ids = fluid.data( - name='sent_ids', shape=[-1, args.max_seq_len], dtype='int64') + name='sent_ids', shape=[-1, args.max_seq_len], dtype='int64') input_mask = fluid.data( - name='input_mask', shape=[-1, args.max_seq_len], dtype='float32') - if args.task_name == 'atis_slot': + name='input_mask', + shape=[-1, args.max_seq_len], + dtype='float32') + if args.task_name == 'atis_slot': labels = fluid.data( - name='labels', shape=[-1, args.max_seq_len], dtype='int64') + name='labels', shape=[-1, args.max_seq_len], dtype='int64') elif args.task_name in ['dstc2', 'dstc2_asr', 'multi-woz']: labels = fluid.data( - name='labels', shape=[-1, num_labels], dtype='int64') - else: - labels = fluid.data( - name='labels', shape=[-1, 1], dtype='int64') - + name='labels', shape=[-1, num_labels], dtype='int64') + else: + labels = fluid.data(name='labels', shape=[-1, 1], dtype='int64') + input_inst = [src_ids, pos_ids, sent_ids, input_mask, labels] input_field = InputField(input_inst) - data_reader = fluid.io.PyReader(feed_list=input_inst, - capacity=4, iterable=False) - + data_reader = fluid.io.PyReader( + feed_list=input_inst, capacity=4, iterable=False) + results = create_net( - is_training=False, - model_input=input_field, - num_labels=num_labels, - paradigm_inst=paradigm_inst, - args=args) + is_training=False, + model_input=input_field, + num_labels=num_labels, + paradigm_inst=paradigm_inst, + args=args) probs = results.get("probs", None) @@ -117,7 +119,7 @@ def do_predict(args): save_load_io.init_from_pretrain_model(args, exe, test_prog) compiled_test_prog = fluid.CompiledProgram(test_prog) - + processor = processors[task_name](data_dir=args.data_dir, vocab_path=args.vocab_path, max_seq_len=args.max_seq_len, @@ -126,34 +128,32 @@ def do_predict(args): task_name=task_name, random_seed=args.random_seed) batch_generator = processor.data_generator( - batch_size=args.batch_size, - phase='test', - shuffle=False) + batch_size=args.batch_size, phase='test', shuffle=False) - data_reader.decorate_batch_generator(batch_generator) + data_reader.decorate_batch_generator(batch_generator) data_reader.start() - + all_results = [] - while True: - try: + while True: + try: results = exe.run(compiled_test_prog, fetch_list=fetch_list) all_results.extend(results[0]) - except fluid.core.EOFException: + except fluid.core.EOFException: data_reader.reset() break np.set_printoptions(precision=4, suppress=True) print("Write the predicted results into the output_prediction_file") - + fw = io.open(args.output_prediction_file, 'w', encoding="utf8") - if task_name not in ['atis_slot']: + if task_name not in ['atis_slot']: for index, result in enumerate(all_results): tags = pred_func(result) - fw.write("%s\t%s\n" % (index, tags)) + fw.write("%s%s%s%s" % (index, tab_tok, tags, rt_tok)) else: tags = pred_func(all_results, args.max_seq_len) for index, tag in enumerate(tags): - fw.write("%s\t%s\n" % (index, tag)) + fw.write("%s%s%s%s" % (index, tab_tok, tag, rt_tok)) if __name__ == "__main__": -- GitLab