提交 6d611bb8 编写于 作者: 1024的传说's avatar 1024的传说 提交者: Guo Sheng

fix utf-8 error in python2 when predict (#3721)

上级 5a6f6822
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
if sys.version[0] == '2':
rt_tok = u'\n'
tab_tok = u'\t'
space_tok = u' '
else:
rt_tok = '\n'
tab_tok = '\t'
space_tok = ' '
......@@ -31,6 +31,7 @@ from dgu.utils.configure import PDConfig
from dgu.utils.input_field import InputField
from dgu.utils.model_check import check_cuda
import dgu.utils.save_load_io as save_load_io
from dgu.utils.py23 import tab_tok, rt_tok
def do_predict(args):
......@@ -69,7 +70,9 @@ def do_predict(args):
sent_ids = fluid.data(
name='sent_ids', shape=[-1, args.max_seq_len], dtype='int64')
input_mask = fluid.data(
name='input_mask', shape=[-1, args.max_seq_len], dtype='float32')
name='input_mask',
shape=[-1, args.max_seq_len],
dtype='float32')
if args.task_name == 'atis_slot':
labels = fluid.data(
name='labels', shape=[-1, args.max_seq_len], dtype='int64')
......@@ -77,13 +80,12 @@ def do_predict(args):
labels = fluid.data(
name='labels', shape=[-1, num_labels], dtype='int64')
else:
labels = fluid.data(
name='labels', shape=[-1, 1], dtype='int64')
labels = fluid.data(name='labels', shape=[-1, 1], dtype='int64')
input_inst = [src_ids, pos_ids, sent_ids, input_mask, labels]
input_field = InputField(input_inst)
data_reader = fluid.io.PyReader(feed_list=input_inst,
capacity=4, iterable=False)
data_reader = fluid.io.PyReader(
feed_list=input_inst, capacity=4, iterable=False)
results = create_net(
is_training=False,
......@@ -126,9 +128,7 @@ def do_predict(args):
task_name=task_name,
random_seed=args.random_seed)
batch_generator = processor.data_generator(
batch_size=args.batch_size,
phase='test',
shuffle=False)
batch_size=args.batch_size, phase='test', shuffle=False)
data_reader.decorate_batch_generator(batch_generator)
data_reader.start()
......@@ -149,11 +149,11 @@ def do_predict(args):
if task_name not in ['atis_slot']:
for index, result in enumerate(all_results):
tags = pred_func(result)
fw.write("%s\t%s\n" % (index, tags))
fw.write("%s%s%s%s" % (index, tab_tok, tags, rt_tok))
else:
tags = pred_func(all_results, args.max_seq_len)
for index, tag in enumerate(tags):
fw.write("%s\t%s\n" % (index, tag))
fw.write("%s%s%s%s" % (index, tab_tok, tag, rt_tok))
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册