未验证 提交 3082e40a 编写于 作者: Z zhang wenhui 提交者: GitHub

fix io.open in PaddleRec (#3717)



* fix io.open,test=develop
上级 0adeee11
......@@ -6,6 +6,7 @@ import numpy as np
import paddle.fluid as fluid
import paddle
import os
import io
def to_lodtensor(data, place):
......@@ -86,7 +87,7 @@ def to_lodtensor_bpr_test(raw_data, vocab_size, place):
def get_vocab_size(vocab_path):
with open(vocab_path, "r", encoding='utf-8') as rf:
with io.open(vocab_path, "r", encoding='utf-8') as rf:
line = rf.readline()
return int(line.strip())
......@@ -115,20 +116,22 @@ def prepare_data(file_dir,
file_dir, buffer_size, data_type=DataType.SEQ), batch_size)
return vocab_size, reader
def check_version():
"""
"""
Log error and exit when the installed version of paddlepaddle is
not satisfied.
"""
err = "PaddlePaddle version 1.6 or higher is required, " \
"or a suitable develop version is satisfied as well. \n" \
"Please make sure the version is good with your code." \
err = "PaddlePaddle version 1.6 or higher is required, " \
"or a suitable develop version is satisfied as well. \n" \
"Please make sure the version is good with your code." \
try:
fluid.require_version('1.6.0')
except Exception as e:
logger.error(err)
sys.exit(1)
try:
fluid.require_version('1.6.0')
except Exception as e:
logger.error(err)
sys.exit(1)
def sort_batch(reader, batch_size, sort_group_size, drop_last=False):
"""
......@@ -184,7 +187,8 @@ def reader_creator(file_dir, n, data_type):
def reader():
files = os.listdir(file_dir)
for fi in files:
with open(os.path.join(file_dir, fi), "r", encoding='utf-8') as f:
with io.open(
os.path.join(file_dir, fi), "r", encoding='utf-8') as f:
for l in f:
if DataType.SEQ == data_type:
l = l.strip().split()
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import random
import io
class Dataset:
......@@ -33,7 +34,7 @@ class YoochooseVocab(Vocab):
def load(self, filelist):
idx = 0
for f in filelist:
with open(f, "r", encoding='utf-8') as fin:
with io.open(f, "r", encoding='utf-8') as fin:
for line in fin:
group = line.strip().split()
for item in group:
......@@ -64,7 +65,7 @@ class YoochooseDataset(Dataset):
def _reader_creator(self, filelist, is_train):
def reader():
for f in filelist:
with open(f, 'r', encoding='utf-8') as fin:
with io.open(f, 'r', encoding='utf-8') as fin:
line_idx = 0
for line in fin:
ids = line.strip().split()
......
......@@ -4,10 +4,11 @@ import os
import logging
import paddle.fluid as fluid
import paddle
import io
def get_vocab_size(vocab_path):
with open(vocab_path, "r", encoding='utf-8') as rf:
with io.open(vocab_path, "r", encoding='utf-8') as rf:
line = rf.readline()
return int(line.strip())
......@@ -30,20 +31,22 @@ def construct_test_data(file_dir, vocab_path, batch_size):
test_reader = fluid.io.batch(y_data.test(files), batch_size=batch_size)
return test_reader, vocab_size
def check_version():
"""
"""
Log error and exit when the installed version of paddlepaddle is
not satisfied.
"""
err = "PaddlePaddle version 1.6 or higher is required, " \
"or a suitable develop version is satisfied as well. \n" \
"Please make sure the version is good with your code." \
try:
fluid.require_version('1.6.0')
except Exception as e:
logger.error(err)
sys.exit(1)
err = "PaddlePaddle version 1.6 or higher is required, " \
"or a suitable develop version is satisfied as well. \n" \
"Please make sure the version is good with your code." \
try:
fluid.require_version('1.6.0')
except Exception as e:
logger.error(err)
sys.exit(1)
def infer_data(raw_data, place):
data = [dat[0] for dat in raw_data]
......
......@@ -8,9 +8,11 @@ import numpy as np
import paddle.fluid as fluid
import paddle
import csv
import io
reload(sys)
sys.setdefaultencoding('utf-8')
def to_lodtensor(data, place):
""" convert to LODtensor """
seq_lens = [len(seq) for seq in data]
......@@ -26,25 +28,27 @@ def to_lodtensor(data, place):
res.set_lod([lod])
return res
def get_vocab_size(vocab_path):
with open(vocab_path, "r") as rf:
with io.open(vocab_path, "r") as rf:
line = rf.readline()
return int(line.strip())
def check_version():
"""
"""
Log error and exit when the installed version of paddlepaddle is
not satisfied.
"""
err = "PaddlePaddle version 1.6 or higher is required, " \
"or a suitable develop version is satisfied as well. \n" \
"Please make sure the version is good with your code." \
err = "PaddlePaddle version 1.6 or higher is required, " \
"or a suitable develop version is satisfied as well. \n" \
"Please make sure the version is good with your code." \
try:
fluid.require_version('1.6.0')
except Exception as e:
logger.error(err)
sys.exit(1)
try:
fluid.require_version('1.6.0')
except Exception as e:
logger.error(err)
sys.exit(1)
def prepare_data(file_dir,
......@@ -62,19 +66,25 @@ def prepare_data(file_dir,
reader = sort_batch(
paddle.reader.shuffle(
train(
file_dir, vocab_tag_size, neg_size,
buffer_size, data_type=DataType.SEQ),
file_dir,
vocab_tag_size,
neg_size,
buffer_size,
data_type=DataType.SEQ),
buf_size=buffer_size),
batch_size, batch_size * 20)
batch_size,
batch_size * 20)
else:
vocab_tag_size = get_vocab_size(vocab_tag_path)
vocab_text_size = 0
reader = sort_batch(
test(
file_dir, vocab_tag_size, buffer_size, data_type=DataType.SEQ),
batch_size, batch_size * 20)
batch_size,
batch_size * 20)
return vocab_text_size, vocab_tag_size, reader
def sort_batch(reader, batch_size, sort_group_size, drop_last=False):
"""
Create a batched reader.
......@@ -124,11 +134,13 @@ def sort_batch(reader, batch_size, sort_group_size, drop_last=False):
class DataType(object):
SEQ = 2
def train_reader_creator(file_dir, tag_size, neg_size, n, data_type):
def reader():
files = os.listdir(file_dir)
for fi in files:
with open(os.path.join(file_dir, fi), "r", encoding='utf-8') as f:
with io.open(
os.path.join(file_dir, fi), "r", encoding='utf-8') as f:
for l in f:
l = l.strip().split(",")
pos_index = int(l[0])
......@@ -140,7 +152,7 @@ def train_reader_creator(file_dir, tag_size, neg_size, n, data_type):
max_iter = 100
now_iter = 0
sum_n = 0
while(sum_n < neg_size) :
while (sum_n < neg_size):
now_iter += 1
if now_iter > max_iter:
print("error : only one class")
......@@ -152,13 +164,16 @@ def train_reader_creator(file_dir, tag_size, neg_size, n, data_type):
sum_n += 1
if n > 0 and len(text) > n: continue
yield text, pos_tag, neg_tag
return reader
def test_reader_creator(file_dir, tag_size, n, data_type):
def reader():
files = os.listdir(file_dir)
for fi in files:
with open(os.path.join(file_dir, fi), "r", encoding='utf-8') as f:
with io.open(
os.path.join(file_dir, fi), "r", encoding='utf-8') as f:
for l in f:
l = l.strip().split(",")
pos_index = int(l[0])
......@@ -170,11 +185,13 @@ def test_reader_creator(file_dir, tag_size, n, data_type):
tag = []
tag.append(ii)
yield text, tag, pos_tag
return reader
def train(train_dir, tag_size, neg_size, n, data_type=DataType.SEQ):
return train_reader_creator(train_dir, tag_size, neg_size, n, data_type)
def test(test_dir, tag_size, n, data_type=DataType.SEQ):
return test_reader_creator(test_dir, tag_size, n, data_type)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册