未验证 提交 3082e40a 编写于 作者: Z zhang wenhui 提交者: GitHub

Fix file opening in PaddleRec: use io.open instead of the built-in open (#3717)



* fix io.open: replace built-in open() calls with io.open(), test=develop
上级 0adeee11
...@@ -6,6 +6,7 @@ import numpy as np ...@@ -6,6 +6,7 @@ import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle import paddle
import os import os
import io
def to_lodtensor(data, place): def to_lodtensor(data, place):
...@@ -86,7 +87,7 @@ def to_lodtensor_bpr_test(raw_data, vocab_size, place): ...@@ -86,7 +87,7 @@ def to_lodtensor_bpr_test(raw_data, vocab_size, place):
def get_vocab_size(vocab_path): def get_vocab_size(vocab_path):
with open(vocab_path, "r", encoding='utf-8') as rf: with io.open(vocab_path, "r", encoding='utf-8') as rf:
line = rf.readline() line = rf.readline()
return int(line.strip()) return int(line.strip())
...@@ -115,20 +116,22 @@ def prepare_data(file_dir, ...@@ -115,20 +116,22 @@ def prepare_data(file_dir,
file_dir, buffer_size, data_type=DataType.SEQ), batch_size) file_dir, buffer_size, data_type=DataType.SEQ), batch_size)
return vocab_size, reader return vocab_size, reader
def check_version(): def check_version():
""" """
Log error and exit when the installed version of paddlepaddle is Log error and exit when the installed version of paddlepaddle is
not satisfied. not satisfied.
""" """
err = "PaddlePaddle version 1.6 or higher is required, " \ err = "PaddlePaddle version 1.6 or higher is required, " \
"or a suitable develop version is satisfied as well. \n" \ "or a suitable develop version is satisfied as well. \n" \
"Please make sure the version is good with your code." \ "Please make sure the version is good with your code." \
try:
fluid.require_version('1.6.0')
except Exception as e:
logger.error(err)
sys.exit(1)
try:
fluid.require_version('1.6.0')
except Exception as e:
logger.error(err)
sys.exit(1)
def sort_batch(reader, batch_size, sort_group_size, drop_last=False): def sort_batch(reader, batch_size, sort_group_size, drop_last=False):
""" """
...@@ -184,7 +187,8 @@ def reader_creator(file_dir, n, data_type): ...@@ -184,7 +187,8 @@ def reader_creator(file_dir, n, data_type):
def reader(): def reader():
files = os.listdir(file_dir) files = os.listdir(file_dir)
for fi in files: for fi in files:
with open(os.path.join(file_dir, fi), "r", encoding='utf-8') as f: with io.open(
os.path.join(file_dir, fi), "r", encoding='utf-8') as f:
for l in f: for l in f:
if DataType.SEQ == data_type: if DataType.SEQ == data_type:
l = l.strip().split() l = l.strip().split()
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import random import random
import io
class Dataset: class Dataset:
...@@ -33,7 +34,7 @@ class YoochooseVocab(Vocab): ...@@ -33,7 +34,7 @@ class YoochooseVocab(Vocab):
def load(self, filelist): def load(self, filelist):
idx = 0 idx = 0
for f in filelist: for f in filelist:
with open(f, "r", encoding='utf-8') as fin: with io.open(f, "r", encoding='utf-8') as fin:
for line in fin: for line in fin:
group = line.strip().split() group = line.strip().split()
for item in group: for item in group:
...@@ -64,7 +65,7 @@ class YoochooseDataset(Dataset): ...@@ -64,7 +65,7 @@ class YoochooseDataset(Dataset):
def _reader_creator(self, filelist, is_train): def _reader_creator(self, filelist, is_train):
def reader(): def reader():
for f in filelist: for f in filelist:
with open(f, 'r', encoding='utf-8') as fin: with io.open(f, 'r', encoding='utf-8') as fin:
line_idx = 0 line_idx = 0
for line in fin: for line in fin:
ids = line.strip().split() ids = line.strip().split()
......
...@@ -4,10 +4,11 @@ import os ...@@ -4,10 +4,11 @@ import os
import logging import logging
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle import paddle
import io
def get_vocab_size(vocab_path): def get_vocab_size(vocab_path):
with open(vocab_path, "r", encoding='utf-8') as rf: with io.open(vocab_path, "r", encoding='utf-8') as rf:
line = rf.readline() line = rf.readline()
return int(line.strip()) return int(line.strip())
...@@ -30,20 +31,22 @@ def construct_test_data(file_dir, vocab_path, batch_size): ...@@ -30,20 +31,22 @@ def construct_test_data(file_dir, vocab_path, batch_size):
test_reader = fluid.io.batch(y_data.test(files), batch_size=batch_size) test_reader = fluid.io.batch(y_data.test(files), batch_size=batch_size)
return test_reader, vocab_size return test_reader, vocab_size
def check_version(): def check_version():
""" """
Log error and exit when the installed version of paddlepaddle is Log error and exit when the installed version of paddlepaddle is
not satisfied. not satisfied.
""" """
err = "PaddlePaddle version 1.6 or higher is required, " \ err = "PaddlePaddle version 1.6 or higher is required, " \
"or a suitable develop version is satisfied as well. \n" \ "or a suitable develop version is satisfied as well. \n" \
"Please make sure the version is good with your code." \ "Please make sure the version is good with your code." \
try: try:
fluid.require_version('1.6.0') fluid.require_version('1.6.0')
except Exception as e: except Exception as e:
logger.error(err) logger.error(err)
sys.exit(1) sys.exit(1)
def infer_data(raw_data, place): def infer_data(raw_data, place):
data = [dat[0] for dat in raw_data] data = [dat[0] for dat in raw_data]
......
...@@ -8,9 +8,11 @@ import numpy as np ...@@ -8,9 +8,11 @@ import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle import paddle
import csv import csv
import io
reload(sys) reload(sys)
sys.setdefaultencoding('utf-8') sys.setdefaultencoding('utf-8')
def to_lodtensor(data, place): def to_lodtensor(data, place):
""" convert to LODtensor """ """ convert to LODtensor """
seq_lens = [len(seq) for seq in data] seq_lens = [len(seq) for seq in data]
...@@ -26,25 +28,27 @@ def to_lodtensor(data, place): ...@@ -26,25 +28,27 @@ def to_lodtensor(data, place):
res.set_lod([lod]) res.set_lod([lod])
return res return res
def get_vocab_size(vocab_path): def get_vocab_size(vocab_path):
with open(vocab_path, "r") as rf: with io.open(vocab_path, "r") as rf:
line = rf.readline() line = rf.readline()
return int(line.strip()) return int(line.strip())
def check_version(): def check_version():
""" """
Log error and exit when the installed version of paddlepaddle is Log error and exit when the installed version of paddlepaddle is
not satisfied. not satisfied.
""" """
err = "PaddlePaddle version 1.6 or higher is required, " \ err = "PaddlePaddle version 1.6 or higher is required, " \
"or a suitable develop version is satisfied as well. \n" \ "or a suitable develop version is satisfied as well. \n" \
"Please make sure the version is good with your code." \ "Please make sure the version is good with your code." \
try: try:
fluid.require_version('1.6.0') fluid.require_version('1.6.0')
except Exception as e: except Exception as e:
logger.error(err) logger.error(err)
sys.exit(1) sys.exit(1)
def prepare_data(file_dir, def prepare_data(file_dir,
...@@ -62,19 +66,25 @@ def prepare_data(file_dir, ...@@ -62,19 +66,25 @@ def prepare_data(file_dir,
reader = sort_batch( reader = sort_batch(
paddle.reader.shuffle( paddle.reader.shuffle(
train( train(
file_dir, vocab_tag_size, neg_size, file_dir,
buffer_size, data_type=DataType.SEQ), vocab_tag_size,
neg_size,
buffer_size,
data_type=DataType.SEQ),
buf_size=buffer_size), buf_size=buffer_size),
batch_size, batch_size * 20) batch_size,
batch_size * 20)
else: else:
vocab_tag_size = get_vocab_size(vocab_tag_path) vocab_tag_size = get_vocab_size(vocab_tag_path)
vocab_text_size = 0 vocab_text_size = 0
reader = sort_batch( reader = sort_batch(
test( test(
file_dir, vocab_tag_size, buffer_size, data_type=DataType.SEQ), file_dir, vocab_tag_size, buffer_size, data_type=DataType.SEQ),
batch_size, batch_size * 20) batch_size,
batch_size * 20)
return vocab_text_size, vocab_tag_size, reader return vocab_text_size, vocab_tag_size, reader
def sort_batch(reader, batch_size, sort_group_size, drop_last=False): def sort_batch(reader, batch_size, sort_group_size, drop_last=False):
""" """
Create a batched reader. Create a batched reader.
...@@ -124,11 +134,13 @@ def sort_batch(reader, batch_size, sort_group_size, drop_last=False): ...@@ -124,11 +134,13 @@ def sort_batch(reader, batch_size, sort_group_size, drop_last=False):
class DataType(object): class DataType(object):
SEQ = 2 SEQ = 2
def train_reader_creator(file_dir, tag_size, neg_size, n, data_type): def train_reader_creator(file_dir, tag_size, neg_size, n, data_type):
def reader(): def reader():
files = os.listdir(file_dir) files = os.listdir(file_dir)
for fi in files: for fi in files:
with open(os.path.join(file_dir, fi), "r", encoding='utf-8') as f: with io.open(
os.path.join(file_dir, fi), "r", encoding='utf-8') as f:
for l in f: for l in f:
l = l.strip().split(",") l = l.strip().split(",")
pos_index = int(l[0]) pos_index = int(l[0])
...@@ -140,7 +152,7 @@ def train_reader_creator(file_dir, tag_size, neg_size, n, data_type): ...@@ -140,7 +152,7 @@ def train_reader_creator(file_dir, tag_size, neg_size, n, data_type):
max_iter = 100 max_iter = 100
now_iter = 0 now_iter = 0
sum_n = 0 sum_n = 0
while(sum_n < neg_size) : while (sum_n < neg_size):
now_iter += 1 now_iter += 1
if now_iter > max_iter: if now_iter > max_iter:
print("error : only one class") print("error : only one class")
...@@ -152,13 +164,16 @@ def train_reader_creator(file_dir, tag_size, neg_size, n, data_type): ...@@ -152,13 +164,16 @@ def train_reader_creator(file_dir, tag_size, neg_size, n, data_type):
sum_n += 1 sum_n += 1
if n > 0 and len(text) > n: continue if n > 0 and len(text) > n: continue
yield text, pos_tag, neg_tag yield text, pos_tag, neg_tag
return reader return reader
def test_reader_creator(file_dir, tag_size, n, data_type): def test_reader_creator(file_dir, tag_size, n, data_type):
def reader(): def reader():
files = os.listdir(file_dir) files = os.listdir(file_dir)
for fi in files: for fi in files:
with open(os.path.join(file_dir, fi), "r", encoding='utf-8') as f: with io.open(
os.path.join(file_dir, fi), "r", encoding='utf-8') as f:
for l in f: for l in f:
l = l.strip().split(",") l = l.strip().split(",")
pos_index = int(l[0]) pos_index = int(l[0])
...@@ -170,11 +185,13 @@ def test_reader_creator(file_dir, tag_size, n, data_type): ...@@ -170,11 +185,13 @@ def test_reader_creator(file_dir, tag_size, n, data_type):
tag = [] tag = []
tag.append(ii) tag.append(ii)
yield text, tag, pos_tag yield text, tag, pos_tag
return reader return reader
def train(train_dir, tag_size, neg_size, n, data_type=DataType.SEQ): def train(train_dir, tag_size, neg_size, n, data_type=DataType.SEQ):
return train_reader_creator(train_dir, tag_size, neg_size, n, data_type) return train_reader_creator(train_dir, tag_size, neg_size, n, data_type)
def test(test_dir, tag_size, n, data_type=DataType.SEQ): def test(test_dir, tag_size, n, data_type=DataType.SEQ):
return test_reader_creator(test_dir, tag_size, n, data_type) return test_reader_creator(test_dir, tag_size, n, data_type)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册