提交 3584aeec 编写于 作者: 王肖 提交者: Xiaoyao Xi

fix dgu text encoding (#4028)

上级 97b6d5cf
@@ -23,6 +23,8 @@ import numpy as np
 from dgu import tokenization
 from dgu.batching import prepare_batch_data
+reload(sys)
+sys.setdefaultencoding('utf-8')
 class DataProcessor(object):
 """Base class for data converters for sequence classification data sets."""
......
@@ -73,11 +73,11 @@ class ATIS(object):
 if example[1] not in self.intent_dict:
 self.intent_dict[example[1]] = self.intent_id
 self.intent_id += 1
-fw.write("%s\t%s\n" % (self.intent_dict[example[1]], example[0].lower()))
+fw.write(u"%s\t%s\n" % (self.intent_dict[example[1]], example[0].lower()))
 fw = io.open(self.map_tag_intent, 'w', encoding="utf8")
 for tag in self.intent_dict:
-fw.write("%s\t%s\n" % (tag, self.intent_dict[tag]))
+fw.write(u"%s\t%s\n" % (tag, self.intent_dict[tag]))
 def _parser_slot_data(self, examples, data_type):
 """
@@ -119,11 +119,11 @@ class ATIS(object):
 if entities[-1]['end'] < len(text):
 suffix_num = len(text[entities[-1]['end']:].strip().split())
 tags.extend([str(self.slot_dict['O'])] * suffix_num)
-fw.write("%s\t%s\n" % (text.encode('utf8'), " ".join(tags).encode('utf8')))
+fw.write(u"%s\t%s\n" % (text.encode('utf8'), " ".join(tags).encode('utf8')))
 fw = io.open(self.map_tag_slot, 'w', encoding="utf8")
 for slot in self.slot_dict:
-fw.write("%s\t%s\n" % (slot, self.slot_dict[slot]))
+fw.write(u"%s\t%s\n" % (slot, self.slot_dict[slot]))
 def get_train_dataset(self):
 """
......
@@ -106,8 +106,8 @@ class DSTC2(object):
 out = "%s\t%s\1%s\t%s" % (session_id, mach, user, labels_ids)
 user_asr = log_turn['input']['live']['asr-hyps'][0]['asr-hyp'].strip()
 out_asr = "%s\t%s\1%s\t%s" % (session_id, mach, user_asr, labels_ids)
-fw.write("%s\n" % out.encode('utf8'))
-fw_asr.write("%s\n" % out_asr.encode('utf8'))
+fw.write(u"%s\n" % out.encode('utf8'))
+fw_asr.write(u"%s\n" % out_asr.encode('utf8'))
 def get_train_dataset(self):
 """
@@ -133,7 +133,7 @@ class DSTC2(object):
 """
 fw = io.open(self.map_tag, 'w', encoding="utf8")
 for elem in self.map_tag_dict:
-fw.write("%s\t%s\n" % (elem, self.map_tag_dict[elem]))
+fw.write(u"%s\t%s\n" % (elem, self.map_tag_dict[elem]))
 def main(self):
 """
......
@@ -121,7 +121,7 @@ class MRDA(object):
 caller = elem.split('_')[0].split('-')[-1]
 conv_no = elem.split('_')[0].split('-')[0]
 out = "%s\t%s\t%s\t%s" % (conv_no, self.map_tag_dict[tag], caller, v_trans[0])
-fw.write("%s\n" % out)
+fw.write(u"%s\n" % out)
 def get_train_dataset(self):
 """
@@ -147,7 +147,7 @@ class MRDA(object):
 """
 fw = io.open(self.map_tag, 'w', encoding="utf8")
 for elem in self.map_tag_dict:
-fw.write("%s\t%s\n" % (elem, self.map_tag_dict[elem]))
+fw.write(u"%s\t%s\n" % (elem, self.map_tag_dict[elem]))
 def main(self):
 """
......
@@ -69,7 +69,7 @@ class SWDA(object):
 idx += 1
 continue
 out = self._parser_utterence(r)
-fw.write("%s\n" % out)
+fw.write(u"%s\n" % out)
 def _clean_text(self, text):
 """
@@ -213,7 +213,7 @@ class SWDA(object):
 """
 fw = io.open(self.map_tag, 'w', encoding='utf8')
 for elem in self.map_tag_dict:
-fw.write("%s\t%s\n" % (elem, self.map_tag_dict[elem]))
+fw.write(u"%s\t%s\n" % (elem, self.map_tag_dict[elem]))
 def main(self):
 """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册