From d4e7553729cdca24599b70cda41c22bbcbfeece7 Mon Sep 17 00:00:00 2001 From: qjing666 Date: Thu, 27 Feb 2020 18:51:25 +0800 Subject: [PATCH] fix code style --- paddle_fl/reader/gru4rec_reader.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/paddle_fl/reader/gru4rec_reader.py b/paddle_fl/reader/gru4rec_reader.py index dadb731..0307056 100644 --- a/paddle_fl/reader/gru4rec_reader.py +++ b/paddle_fl/reader/gru4rec_reader.py @@ -2,6 +2,7 @@ import paddle.fluid as fluid import numpy as np import os + class Gru4rec_Reader: def __init__(self): pass @@ -21,7 +22,6 @@ class Gru4rec_Reader: res.set_lod([lod]) return res - def lod_reader(self, reader, place): def feed_reader(): for data in reader(): @@ -33,12 +33,14 @@ class Gru4rec_Reader: fe_data["src_wordseq"] = lod_src_wordseq fe_data["dst_wordseq"] = lod_dst_wordseq yield fe_data + return feed_reader def sort_batch(self, reader, batch_size, sort_group_size, drop_last=False): """ Create a batched reader. """ + def batch_reader(): r = reader() b = [] @@ -66,11 +68,11 @@ class Gru4rec_Reader: # Batch size check batch_size = int(batch_size) if batch_size <= 0: - raise ValueError("batch_size should be a positive integeral value, " - "but got batch_size={}".format(batch_size)) + raise ValueError( + "batch_size should be a positive integeral value, " + "but got batch_size={}".format(batch_size)) return batch_reader - def reader_creator(self, file_dir): def reader(): files = os.listdir(file_dir) @@ -82,10 +84,12 @@ class Gru4rec_Reader: src_seq = l[:len(l) - 1] trg_seq = l[1:] yield src_seq, trg_seq + return reader def reader(self, file_dir, place, batch_size=5): """ prepare the English Pann Treebank (PTB) data """ print("start constuct word dict") - reader = self.sort_batch(self.reader_creator(file_dir), batch_size, batch_size * 20) + reader = self.sort_batch( + self.reader_creator(file_dir), batch_size, batch_size * 20) return self.lod_reader(reader, place) -- GitLab