未验证 提交 775741f3 编写于 作者: Z zhang wenhui 提交者: GitHub

fix python2&3 encode in gru4rec&tagspace (#4072)

上级 a58b30c6
......@@ -3,6 +3,7 @@ import six
import collections
import os
import sys
import io
if six.PY2:
reload(sys)
sys.setdefaultencoding('utf-8')
......@@ -30,11 +31,11 @@ def build_dict(min_word_freq=0, train_dir="", test_dir=""):
word_freq = collections.defaultdict(int)
files = os.listdir(train_dir)
for fi in files:
with open(os.path.join(train_dir, fi), "r") as f:
with io.open(os.path.join(train_dir, fi), "r") as f:
word_freq = word_count(f, word_freq)
files = os.listdir(test_dir)
for fi in files:
with open(os.path.join(test_dir, fi), "r") as f:
with io.open(os.path.join(test_dir, fi), "r") as f:
word_freq = word_count(f, word_freq)
word_freq = [x for x in six.iteritems(word_freq) if x[1] > min_word_freq]
......@@ -50,35 +51,44 @@ def write_paddle(word_idx, train_dir, test_dir, output_train_dir,
if not os.path.exists(output_train_dir):
os.mkdir(output_train_dir)
for fi in files:
with open(os.path.join(train_dir, fi), "r") as f:
with open(os.path.join(output_train_dir, fi), "w") as wf:
with io.open(os.path.join(train_dir, fi), "r") as f:
with io.open(os.path.join(output_train_dir, fi), "w") as wf:
for l in f:
l = l.strip().split()
l = [word_idx.get(w) for w in l]
for w in l:
wf.write(str(w) + " ")
wf.write("\n")
wf.write(str2file(str(w) + " "))
wf.write(str2file("\n"))
files = os.listdir(test_dir)
if not os.path.exists(output_test_dir):
os.mkdir(output_test_dir)
for fi in files:
with open(os.path.join(test_dir, fi), "r") as f:
with open(os.path.join(output_test_dir, fi), "w") as wf:
with io.open(os.path.join(test_dir, fi), "r", encoding='utf-8') as f:
with io.open(
os.path.join(output_test_dir, fi), "w",
encoding='utf-8') as wf:
for l in f:
l = l.strip().split()
l = [word_idx.get(w) for w in l]
for w in l:
wf.write(str(w) + " ")
wf.write("\n")
wf.write(str2file(str(w) + " "))
wf.write(str2file("\n"))
def str2file(str):
if six.PY2:
return str.decode("utf-8")
else:
return str
def text2paddle(train_dir, test_dir, output_train_dir, output_test_dir,
output_vocab):
vocab = build_dict(0, train_dir, test_dir)
with open(output_vocab, "w", encoding='utf-8') as wf:
wf.write(str(len(vocab)) + "\n")
#wf.write(str(vocab))
print("vocab size:", str(len(vocab)))
with io.open(output_vocab, "w", encoding='utf-8') as wf:
wf.write(str2file(str(len(vocab)) + "\n"))
write_paddle(vocab, train_dir, test_dir, output_train_dir, output_test_dir)
......
......@@ -2,29 +2,27 @@ import sys
import six
import collections
import os
import csv
import io
import re
import sys
if six.PY2:
reload(sys)
sys.setdefaultencoding('utf-8')
def word_count(column_num, input_file, word_freq=None):
def word_count(input_file, word_freq=None):
"""
compute word count from corpus
"""
if word_freq is None:
word_freq = collections.defaultdict(int)
data_file = csv.reader(input_file)
for row in data_file:
for w in re.split(r'\W+', row[column_num].strip()):
for l in input_file:
for w in l.strip().split():
word_freq[w] += 1
return word_freq
def build_dict(column_num=2, min_word_freq=0, train_dir="", test_dir=""):
def build_dict(min_word_freq=0, train_dir="", test_dir=""):
"""
Build a word dictionary from the corpus, Keys of the dictionary are words,
and values are zero-based IDs of these words.
......@@ -32,12 +30,12 @@ def build_dict(column_num=2, min_word_freq=0, train_dir="", test_dir=""):
word_freq = collections.defaultdict(int)
files = os.listdir(train_dir)
for fi in files:
with io.open(os.path.join(train_dir, fi), "r", encoding='utf-8') as f:
word_freq = word_count(column_num, f, word_freq)
with open(os.path.join(train_dir, fi), "r") as f:
word_freq = word_count(f, word_freq)
files = os.listdir(test_dir)
for fi in files:
with io.open(os.path.join(test_dir, fi), "r", encoding='utf-8') as f:
word_freq = word_count(column_num, f, word_freq)
with open(os.path.join(test_dir, fi), "r") as f:
word_freq = word_count(f, word_freq)
word_freq = [x for x in six.iteritems(word_freq) if x[1] > min_word_freq]
word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0]))
......@@ -46,7 +44,7 @@ def build_dict(column_num=2, min_word_freq=0, train_dir="", test_dir=""):
return word_idx
def write_paddle(text_idx, tag_idx, train_dir, test_dir, output_train_dir,
def write_paddle(word_idx, train_dir, test_dir, output_train_dir,
output_test_dir):
files = os.listdir(train_dir)
if not os.path.exists(output_train_dir):
......@@ -54,13 +52,9 @@ def write_paddle(text_idx, tag_idx, train_dir, test_dir, output_train_dir,
for fi in files:
with open(os.path.join(train_dir, fi), "r") as f:
with open(os.path.join(output_train_dir, fi), "w") as wf:
data_file = csv.reader(f)
for row in data_file:
tag_raw = re.split(r'\W+', row[0].strip())
pos_index = tag_idx.get(tag_raw[0])
wf.write(str(pos_index) + ",")
text_raw = re.split(r'\W+', row[2].strip())
l = [text_idx.get(w) for w in text_raw]
for l in f:
l = l.strip().split()
l = [word_idx.get(w) for w in l]
for w in l:
wf.write(str(w) + " ")
wf.write("\n")
......@@ -71,40 +65,27 @@ def write_paddle(text_idx, tag_idx, train_dir, test_dir, output_train_dir,
for fi in files:
with open(os.path.join(test_dir, fi), "r") as f:
with open(os.path.join(output_test_dir, fi), "w") as wf:
data_file = csv.reader(f)
for row in data_file:
tag_raw = re.split(r'\W+', row[0].strip())
pos_index = tag_idx.get(tag_raw[0])
wf.write(str(pos_index) + ",")
text_raw = re.split(r'\W+', row[2].strip())
l = [text_idx.get(w) for w in text_raw]
for l in f:
l = l.strip().split()
l = [word_idx.get(w) for w in l]
for w in l:
wf.write(str(w) + " ")
wf.write("\n")
def text2paddle(train_dir, test_dir, output_train_dir, output_test_dir,
output_vocab_text, output_vocab_tag):
print("start constuct word dict")
vocab_text = build_dict(2, 0, train_dir, test_dir)
with io.open(output_vocab_text, "w", encoding='utf-8') as wf:
wf.write((str(len(vocab_text)) + "\n").decode('utf-8'))
vocab_tag = build_dict(0, 0, train_dir, test_dir)
with io.open(output_vocab_tag, "w", encoding='utf-8') as wf:
wf.write((str(len(vocab_tag)) + "\n").decode('utf-8'))
print("construct word dict done\n")
write_paddle(vocab_text, vocab_tag, train_dir, test_dir, output_train_dir,
output_test_dir)
output_vocab):
vocab = build_dict(0, train_dir, test_dir)
with open(output_vocab, "w", encoding='utf-8') as wf:
wf.write(str(len(vocab)) + "\n")
#wf.write(str(vocab))
write_paddle(vocab, train_dir, test_dir, output_train_dir, output_test_dir)
train_dir = sys.argv[1]
test_dir = sys.argv[2]
output_train_dir = sys.argv[3]
output_test_dir = sys.argv[4]
output_vocab_text = sys.argv[5]
output_vocab_tag = sys.argv[6]
output_vocab = sys.argv[5]
text2paddle(train_dir, test_dir, output_train_dir, output_test_dir,
output_vocab_text, output_vocab_tag)
output_vocab)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册