Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleFL
提交
d4e75537
P
PaddleFL
项目概览
PaddlePaddle
/
PaddleFL
通知
35
Star
5
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
6
列表
看板
标记
里程碑
合并请求
4
Wiki
3
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleFL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
6
Issue
6
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
3
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d4e75537
编写于
2月 27, 2020
作者:
Q
qjing666
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix code style
上级
6b1fb7cc
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
9 addition
and
5 deletion
+9
-5
paddle_fl/reader/gru4rec_reader.py
paddle_fl/reader/gru4rec_reader.py
+9
-5
未找到文件。
paddle_fl/reader/gru4rec_reader.py
浏览文件 @
d4e75537
...
@@ -2,6 +2,7 @@ import paddle.fluid as fluid
...
@@ -2,6 +2,7 @@ import paddle.fluid as fluid
import
numpy
as
np
import
numpy
as
np
import
os
import
os
class
Gru4rec_Reader
:
class
Gru4rec_Reader
:
def
__init__
(
self
):
def
__init__
(
self
):
pass
pass
...
@@ -21,7 +22,6 @@ class Gru4rec_Reader:
...
@@ -21,7 +22,6 @@ class Gru4rec_Reader:
res
.
set_lod
([
lod
])
res
.
set_lod
([
lod
])
return
res
return
res
def
lod_reader
(
self
,
reader
,
place
):
def
lod_reader
(
self
,
reader
,
place
):
def
feed_reader
():
def
feed_reader
():
for
data
in
reader
():
for
data
in
reader
():
...
@@ -33,12 +33,14 @@ class Gru4rec_Reader:
...
@@ -33,12 +33,14 @@ class Gru4rec_Reader:
fe_data
[
"src_wordseq"
]
=
lod_src_wordseq
fe_data
[
"src_wordseq"
]
=
lod_src_wordseq
fe_data
[
"dst_wordseq"
]
=
lod_dst_wordseq
fe_data
[
"dst_wordseq"
]
=
lod_dst_wordseq
yield
fe_data
yield
fe_data
return
feed_reader
return
feed_reader
def
sort_batch
(
self
,
reader
,
batch_size
,
sort_group_size
,
drop_last
=
False
):
def
sort_batch
(
self
,
reader
,
batch_size
,
sort_group_size
,
drop_last
=
False
):
"""
"""
Create a batched reader.
Create a batched reader.
"""
"""
def
batch_reader
():
def
batch_reader
():
r
=
reader
()
r
=
reader
()
b
=
[]
b
=
[]
...
@@ -66,11 +68,11 @@ class Gru4rec_Reader:
...
@@ -66,11 +68,11 @@ class Gru4rec_Reader:
# Batch size check
# Batch size check
batch_size
=
int
(
batch_size
)
batch_size
=
int
(
batch_size
)
if
batch_size
<=
0
:
if
batch_size
<=
0
:
raise
ValueError
(
"batch_size should be a positive integeral value, "
raise
ValueError
(
"batch_size should be a positive integeral value, "
"but got batch_size={}"
.
format
(
batch_size
))
"but got batch_size={}"
.
format
(
batch_size
))
return
batch_reader
return
batch_reader
def
reader_creator
(
self
,
file_dir
):
def
reader_creator
(
self
,
file_dir
):
def
reader
():
def
reader
():
files
=
os
.
listdir
(
file_dir
)
files
=
os
.
listdir
(
file_dir
)
...
@@ -82,10 +84,12 @@ class Gru4rec_Reader:
...
@@ -82,10 +84,12 @@ class Gru4rec_Reader:
src_seq
=
l
[:
len
(
l
)
-
1
]
src_seq
=
l
[:
len
(
l
)
-
1
]
trg_seq
=
l
[
1
:]
trg_seq
=
l
[
1
:]
yield
src_seq
,
trg_seq
yield
src_seq
,
trg_seq
return
reader
return
reader
def
reader
(
self
,
file_dir
,
place
,
batch_size
=
5
):
def
reader
(
self
,
file_dir
,
place
,
batch_size
=
5
):
""" prepare the English Pann Treebank (PTB) data """
""" prepare the English Pann Treebank (PTB) data """
print
(
"start constuct word dict"
)
print
(
"start constuct word dict"
)
reader
=
self
.
sort_batch
(
self
.
reader_creator
(
file_dir
),
batch_size
,
batch_size
*
20
)
reader
=
self
.
sort_batch
(
self
.
reader_creator
(
file_dir
),
batch_size
,
batch_size
*
20
)
return
self
.
lod_reader
(
reader
,
place
)
return
self
.
lod_reader
(
reader
,
place
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录