提交 3584aeec 编写于 作者: 王肖 提交者: Xiaoyao Xi

fix dgu text encoding(#4028)

上级 97b6d5cf
......@@ -23,6 +23,8 @@ import numpy as np
from dgu import tokenization
from dgu.batching import prepare_batch_data
reload(sys)
sys.setdefaultencoding('utf-8')
class DataProcessor(object):
"""Base class for data converters for sequence classification data sets."""
......
......@@ -73,11 +73,11 @@ class ATIS(object):
if example[1] not in self.intent_dict:
self.intent_dict[example[1]] = self.intent_id
self.intent_id += 1
fw.write("%s\t%s\n" % (self.intent_dict[example[1]], example[0].lower()))
fw.write(u"%s\t%s\n" % (self.intent_dict[example[1]], example[0].lower()))
fw = io.open(self.map_tag_intent, 'w', encoding="utf8")
for tag in self.intent_dict:
fw.write("%s\t%s\n" % (tag, self.intent_dict[tag]))
fw.write(u"%s\t%s\n" % (tag, self.intent_dict[tag]))
def _parser_slot_data(self, examples, data_type):
"""
......@@ -119,11 +119,11 @@ class ATIS(object):
if entities[-1]['end'] < len(text):
suffix_num = len(text[entities[-1]['end']:].strip().split())
tags.extend([str(self.slot_dict['O'])] * suffix_num)
fw.write("%s\t%s\n" % (text.encode('utf8'), " ".join(tags).encode('utf8')))
fw.write(u"%s\t%s\n" % (text.encode('utf8'), " ".join(tags).encode('utf8')))
fw = io.open(self.map_tag_slot, 'w', encoding="utf8")
for slot in self.slot_dict:
fw.write("%s\t%s\n" % (slot, self.slot_dict[slot]))
fw.write(u"%s\t%s\n" % (slot, self.slot_dict[slot]))
def get_train_dataset(self):
"""
......
......@@ -106,8 +106,8 @@ class DSTC2(object):
out = "%s\t%s\1%s\t%s" % (session_id, mach, user, labels_ids)
user_asr = log_turn['input']['live']['asr-hyps'][0]['asr-hyp'].strip()
out_asr = "%s\t%s\1%s\t%s" % (session_id, mach, user_asr, labels_ids)
fw.write("%s\n" % out.encode('utf8'))
fw_asr.write("%s\n" % out_asr.encode('utf8'))
fw.write(u"%s\n" % out.encode('utf8'))
fw_asr.write(u"%s\n" % out_asr.encode('utf8'))
def get_train_dataset(self):
"""
......@@ -133,7 +133,7 @@ class DSTC2(object):
"""
fw = io.open(self.map_tag, 'w', encoding="utf8")
for elem in self.map_tag_dict:
fw.write("%s\t%s\n" % (elem, self.map_tag_dict[elem]))
fw.write(u"%s\t%s\n" % (elem, self.map_tag_dict[elem]))
def main(self):
"""
......
......@@ -121,7 +121,7 @@ class MRDA(object):
caller = elem.split('_')[0].split('-')[-1]
conv_no = elem.split('_')[0].split('-')[0]
out = "%s\t%s\t%s\t%s" % (conv_no, self.map_tag_dict[tag], caller, v_trans[0])
fw.write("%s\n" % out)
fw.write(u"%s\n" % out)
def get_train_dataset(self):
"""
......@@ -147,7 +147,7 @@ class MRDA(object):
"""
fw = io.open(self.map_tag, 'w', encoding="utf8")
for elem in self.map_tag_dict:
fw.write("%s\t%s\n" % (elem, self.map_tag_dict[elem]))
fw.write(u"%s\t%s\n" % (elem, self.map_tag_dict[elem]))
def main(self):
"""
......
......@@ -69,7 +69,7 @@ class SWDA(object):
idx += 1
continue
out = self._parser_utterence(r)
fw.write("%s\n" % out)
fw.write(u"%s\n" % out)
def _clean_text(self, text):
"""
......@@ -213,7 +213,7 @@ class SWDA(object):
"""
fw = io.open(self.map_tag, 'w', encoding='utf8')
for elem in self.map_tag_dict:
fw.write("%s\t%s\n" % (elem, self.map_tag_dict[elem]))
fw.write(u"%s\t%s\n" % (elem, self.map_tag_dict[elem]))
def main(self):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册