Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
44d5f42a
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
44d5f42a
编写于
4月 03, 2018
作者:
F
fengjiayi
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update reader
上级
a4e437d5
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
64 addition
and
18 deletion
+64
-18
paddle/fluid/operators/reader/create_batch_reader_op.cc
paddle/fluid/operators/reader/create_batch_reader_op.cc
+5
-2
paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
.../fluid/operators/reader/create_double_buffer_reader_op.cc
+5
-3
paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
+6
-3
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
+6
-2
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+15
-0
python/paddle/fluid/layers/io.py
python/paddle/fluid/layers/io.py
+27
-8
未找到文件。
paddle/fluid/operators/reader/create_batch_reader_op.cc
浏览文件 @
44d5f42a
...
...
@@ -39,10 +39,13 @@ class CreateBatchReaderOp : public framework::OperatorBase {
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
template
GetMutable
<
framework
::
ReaderHolder
>();
if
(
out
->
Get
()
!=
nullptr
)
{
return
;
}
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
out
->
Reset
(
new
BatchReader
(
underlying_reader
.
Get
(),
Attr
<
int
>
(
"batch_size"
)));
}
...
...
paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
浏览文件 @
44d5f42a
...
...
@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <thread>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
...
...
@@ -98,10 +97,13 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
template
GetMutable
<
framework
::
ReaderHolder
>();
if
(
out
->
Get
()
!=
nullptr
)
{
return
;
}
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
auto
place_str
=
Attr
<
std
::
string
>
(
"place"
);
platform
::
Place
place
;
...
...
paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
浏览文件 @
44d5f42a
...
...
@@ -62,12 +62,15 @@ class CreateMultiPassReaderOp : public framework::OperatorBase {
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
*
out
=
detail
::
Ref
(
scope
.
FindVar
(
Output
(
"Out"
)))
.
GetMutable
<
framework
::
ReaderHolder
>
();
if
(
out
->
Get
()
!=
nullptr
)
{
return
;
}
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
auto
&
out
=
detail
::
Ref
(
scope
.
FindVar
(
Output
(
"Out"
)));
int
pass_num
=
Attr
<
int
>
(
"pass_num"
);
out
.
GetMutable
<
framework
::
ReaderHolder
>
()
->
Reset
(
new
MultiPassReader
(
underlying_reader
.
Get
(),
pass_num
));
out
->
Reset
(
new
MultiPassReader
(
underlying_reader
.
Get
(),
pass_num
));
}
};
...
...
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
浏览文件 @
44d5f42a
...
...
@@ -80,10 +80,14 @@ class CreateShuffleReaderOp : public framework::OperatorBase {
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
*
out
=
detail
::
Ref
(
scope
.
FindVar
(
Output
(
"Out"
)))
.
GetMutable
<
framework
::
ReaderHolder
>
();
if
(
out
->
Get
()
!=
nullptr
)
{
return
;
}
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
auto
&
var
=
detail
::
Ref
(
scope
.
FindVar
(
Output
(
"Out"
)));
var
.
GetMutable
<
framework
::
ReaderHolder
>
()
->
Reset
(
out
->
Reset
(
new
ShuffleReader
(
underlying_reader
.
Get
(),
static_cast
<
size_t
>
(
Attr
<
int
>
(
"buffer_size"
))));
}
...
...
python/paddle/fluid/framework.py
浏览文件 @
44d5f42a
...
...
@@ -640,6 +640,21 @@ class Operator(object):
"""
return
self
.
desc
.
block_attr
(
name
)
@
property
def
attrs
(
self
):
"""
Get the attribute dict
Returns(dict): The Operator's attribute dict
"""
attr_names
=
self
.
attr_names
attr_map
=
{}
for
n
in
attr_names
:
if
n
==
'sub_block'
:
attr_map
[
n
]
=
self
.
block_attr
(
n
)
else
:
attr_map
[
n
]
=
self
.
attr
(
n
)
return
attr_map
class
Block
(
object
):
def
__init__
(
self
,
program
,
idx
):
...
...
python/paddle/fluid/layers/io.py
浏览文件 @
44d5f42a
...
...
@@ -255,7 +255,22 @@ def _copy_reader_var_(block, var):
new_var
.
desc
.
set_shapes
(
var
.
desc
.
shapes
())
new_var
.
desc
.
set_dtypes
(
var
.
desc
.
dtypes
())
new_var
.
persistable
=
True
return
monkey_patch_reader_methods
(
new_var
)
return
new_var
def
_copy_reader_create_op_
(
block
,
op
):
def
_find_vars_
(
block
,
name_list
):
res
=
{}
for
n
in
name_list
:
var
=
block
.
var
(
n
)
res
[
n
]
=
var
return
res
input_map
=
_find_vars_
(
block
,
op
.
input_names
)
output_map
=
_find_vars_
(
block
,
op
.
output_names
)
new_op
=
block
.
append_op
(
type
=
op
.
type
,
inputs
=
input_map
,
outputs
=
output_map
,
attrs
=
op
.
attrs
)
return
new_op
def
open_recordio_file
(
filename
,
shapes
,
lod_levels
,
dtypes
):
...
...
@@ -283,8 +298,9 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes):
startup_var
.
desc
.
set_dtypes
(
dtypes
)
startup_var
.
persistable
=
True
return
_copy_reader_var_
(
default_main_program
().
current_block
(),
startup_var
)
main_prog_var
=
_copy_reader_var_
(
default_main_program
().
current_block
(),
startup_var
)
return
monkey_patch_reader_methods
(
main_prog_var
)
def
open_files
(
filenames
,
thread_num
,
shapes
,
lod_levels
,
dtypes
):
...
...
@@ -313,22 +329,25 @@ def open_files(filenames, thread_num, shapes, lod_levels, dtypes):
startup_var
.
desc
.
set_dtypes
(
dtypes
)
startup_var
.
persistable
=
True
return
_copy_reader_var_
(
default_main_program
().
current_block
(),
startup_var
)
main_prog_var
=
_copy_reader_var_
(
default_main_program
().
current_block
(),
startup_var
)
return
monkey_patch_reader_methods
(
main_prog_var
)
def
__create_decorated_reader__
(
op_type
,
reader
,
attrs
):
var_name
=
unique_name
(
op_type
)
startup_blk
=
default_startup_program
().
current_block
()
startup_var
=
startup_blk
.
create_var
(
name
=
var_name
)
startup_blk
.
append_op
(
start
op_op
=
start
up_blk
.
append_op
(
type
=
op_type
,
inputs
=
{
'UnderlyingReader'
:
reader
},
outputs
=
{
'Out'
:
[
startup_var
]},
attrs
=
attrs
)
startup_var
.
persistable
=
True
return
_copy_reader_var_
(
default_main_program
().
current_block
(),
startup_var
)
main_prog_block
=
default_main_program
().
current_block
()
main_prog_var
=
_copy_reader_var_
(
main_prog_block
,
startup_var
)
_copy_reader_create_op_
(
main_prog_block
,
startop_op
)
return
monkey_patch_reader_methods
(
main_prog_var
)
def
create_shuffle_reader
(
reader
,
buffer_size
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录