提交 d4e75537 编写于 作者: Q qjing666

fix code style

上级 6b1fb7cc
...@@ -2,6 +2,7 @@ import paddle.fluid as fluid ...@@ -2,6 +2,7 @@ import paddle.fluid as fluid
import numpy as np import numpy as np
import os import os
class Gru4rec_Reader: class Gru4rec_Reader:
def __init__(self): def __init__(self):
pass pass
...@@ -21,7 +22,6 @@ class Gru4rec_Reader: ...@@ -21,7 +22,6 @@ class Gru4rec_Reader:
res.set_lod([lod]) res.set_lod([lod])
return res return res
def lod_reader(self, reader, place): def lod_reader(self, reader, place):
def feed_reader(): def feed_reader():
for data in reader(): for data in reader():
...@@ -33,12 +33,14 @@ class Gru4rec_Reader: ...@@ -33,12 +33,14 @@ class Gru4rec_Reader:
fe_data["src_wordseq"] = lod_src_wordseq fe_data["src_wordseq"] = lod_src_wordseq
fe_data["dst_wordseq"] = lod_dst_wordseq fe_data["dst_wordseq"] = lod_dst_wordseq
yield fe_data yield fe_data
return feed_reader return feed_reader
def sort_batch(self, reader, batch_size, sort_group_size, drop_last=False): def sort_batch(self, reader, batch_size, sort_group_size, drop_last=False):
""" """
Create a batched reader. Create a batched reader.
""" """
def batch_reader(): def batch_reader():
r = reader() r = reader()
b = [] b = []
...@@ -66,11 +68,11 @@ class Gru4rec_Reader: ...@@ -66,11 +68,11 @@ class Gru4rec_Reader:
# Batch size check # Batch size check
batch_size = int(batch_size) batch_size = int(batch_size)
if batch_size <= 0: if batch_size <= 0:
raise ValueError("batch_size should be a positive integeral value, " raise ValueError(
"batch_size should be a positive integeral value, "
"but got batch_size={}".format(batch_size)) "but got batch_size={}".format(batch_size))
return batch_reader return batch_reader
def reader_creator(self, file_dir): def reader_creator(self, file_dir):
def reader(): def reader():
files = os.listdir(file_dir) files = os.listdir(file_dir)
...@@ -82,10 +84,12 @@ class Gru4rec_Reader: ...@@ -82,10 +84,12 @@ class Gru4rec_Reader:
src_seq = l[:len(l) - 1] src_seq = l[:len(l) - 1]
trg_seq = l[1:] trg_seq = l[1:]
yield src_seq, trg_seq yield src_seq, trg_seq
return reader return reader
def reader(self, file_dir, place, batch_size=5): def reader(self, file_dir, place, batch_size=5):
""" prepare the English Pann Treebank (PTB) data """ """ prepare the English Pann Treebank (PTB) data """
print("start constuct word dict") print("start constuct word dict")
reader = self.sort_batch(self.reader_creator(file_dir), batch_size, batch_size * 20) reader = self.sort_batch(
self.reader_creator(file_dir), batch_size, batch_size * 20)
return self.lod_reader(reader, place) return self.lod_reader(reader, place)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册