未验证 提交 8e66f7f8 编写于 作者: L lvmengsi 提交者: GitHub

Fix fit a line reader (#789)

* fix reader
上级 3477bfac
...@@ -173,18 +173,20 @@ offset = int(data.shape[0]*ratio) ...@@ -173,18 +173,20 @@ offset = int(data.shape[0]*ratio)
train_data = data[:offset] train_data = data[:offset]
test_data = data[offset:] test_data = data[offset:]
def reader(data): def reader_creator(train_data):
def reader():
for d in train_data: for d in train_data:
yield d[:1], d[-1:] yield d[:-1], d[-1:]
return reader
train_reader = paddle.batch( train_reader = paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
reader(train_data), buf_size=500), reader_creator(train_data), buf_size=500),
batch_size=BATCH_SIZE) batch_size=BATCH_SIZE)
test_reader = paddle.batch( test_reader = paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
reader(test_data), buf_size=500), reader_creator(test_data), buf_size=500),
batch_size=BATCH_SIZE) batch_size=BATCH_SIZE)
``` ```
......
...@@ -215,18 +215,20 @@ offset = int(data.shape[0]*ratio) ...@@ -215,18 +215,20 @@ offset = int(data.shape[0]*ratio)
train_data = data[:offset] train_data = data[:offset]
test_data = data[offset:] test_data = data[offset:]
def reader(data): def reader_creator(train_data):
def reader():
for d in train_data: for d in train_data:
yield d[:1], d[-1:] yield d[:-1], d[-1:]
return reader
train_reader = paddle.batch( train_reader = paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
reader(train_data), buf_size=500), reader_creator(train_data), buf_size=500),
batch_size=BATCH_SIZE) batch_size=BATCH_SIZE)
test_reader = paddle.batch( test_reader = paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
reader(test_data), buf_size=500), reader_creator(test_data), buf_size=500),
batch_size=BATCH_SIZE) batch_size=BATCH_SIZE)
``` ```
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册