提交 d4e75537 编写于 作者: Q qjing666

fix code style

上级 6b1fb7cc
......@@ -2,6 +2,7 @@ import paddle.fluid as fluid
import numpy as np
import os
class Gru4rec_Reader:
def __init__(self):
pass
......@@ -21,7 +22,6 @@ class Gru4rec_Reader:
res.set_lod([lod])
return res
def lod_reader(self, reader, place):
def feed_reader():
for data in reader():
......@@ -33,12 +33,14 @@ class Gru4rec_Reader:
fe_data["src_wordseq"] = lod_src_wordseq
fe_data["dst_wordseq"] = lod_dst_wordseq
yield fe_data
return feed_reader
def sort_batch(self, reader, batch_size, sort_group_size, drop_last=False):
"""
Create a batched reader.
"""
def batch_reader():
r = reader()
b = []
......@@ -66,11 +68,11 @@ class Gru4rec_Reader:
# Batch size check
batch_size = int(batch_size)
if batch_size <= 0:
raise ValueError("batch_size should be a positive integeral value, "
raise ValueError(
"batch_size should be a positive integeral value, "
"but got batch_size={}".format(batch_size))
return batch_reader
def reader_creator(self, file_dir):
def reader():
files = os.listdir(file_dir)
......@@ -82,10 +84,12 @@ class Gru4rec_Reader:
src_seq = l[:len(l) - 1]
trg_seq = l[1:]
yield src_seq, trg_seq
return reader
def reader(self, file_dir, place, batch_size=5):
""" prepare the English Pann Treebank (PTB) data """
print("start constuct word dict")
reader = self.sort_batch(self.reader_creator(file_dir), batch_size, batch_size * 20)
reader = self.sort_batch(
self.reader_creator(file_dir), batch_size, batch_size * 20)
return self.lod_reader(reader, place)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册