Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
b48eba19
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b48eba19
编写于
5月 23, 2018
作者:
F
fengjiayi
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
complete python API and unit test
上级
983c9a2a
变更
3
隐藏空白更改
内联
并排
Showing
3 changed files
with
156 additions
and
11 deletions
+156
-11
paddle/fluid/operators/reader/create_custom_reader_op.cc
paddle/fluid/operators/reader/create_custom_reader_op.cc
+9
-8
python/paddle/fluid/layers/io.py
python/paddle/fluid/layers/io.py
+81
-3
python/paddle/fluid/tests/unittests/test_preprocessor.py
python/paddle/fluid/tests/unittests/test_preprocessor.py
+66
-0
未找到文件。
paddle/fluid/operators/reader/create_custom_reader_op.cc
浏览文件 @
b48eba19
...
@@ -65,9 +65,8 @@ class CreateCustomReaderOp : public framework::OperatorBase {
...
@@ -65,9 +65,8 @@ class CreateCustomReaderOp : public framework::OperatorBase {
};
};
class
CreateCustomReaderOpMaker
:
public
DecoratedReaderMakerBase
{
class
CreateCustomReaderOpMaker
:
public
DecoratedReaderMakerBase
{
public:
protected:
CreateCustomReaderOpMaker
(
OpProto
*
op_proto
,
OpAttrChecker
*
op_checker
)
void
Apply
()
override
{
:
DecoratedReaderMakerBase
(
op_proto
,
op_checker
)
{
AddAttr
<
framework
::
BlockDesc
*>
(
"sub_block"
,
""
);
AddAttr
<
framework
::
BlockDesc
*>
(
"sub_block"
,
""
);
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"source_var_names"
,
""
);
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"source_var_names"
,
""
);
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"sink_var_names"
,
""
);
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"sink_var_names"
,
""
);
...
@@ -86,13 +85,14 @@ class CustomReaderInferShape : public framework::InferShapeBase {
...
@@ -86,13 +85,14 @@ class CustomReaderInferShape : public framework::InferShapeBase {
"compile time."
);
"compile time."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"The output decorated reader should not be null."
);
"The output decorated reader should not be null."
);
const
auto
*
sub_block
=
ctx
->
Attrs
().
Get
<
framework
::
BlockDesc
*>
(
"sub_block"
);
const
auto
sink_var_names
=
const
auto
sink_var_names
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
std
::
string
>>
(
"sink_var_names"
);
ctx
->
Attrs
().
Get
<
std
::
vector
<
std
::
string
>>
(
"sink_var_names"
);
std
::
vector
<
std
::
vector
<
int64_t
>>
res_dims
;
std
::
vector
<
std
::
vector
<
int64_t
>>
res_dims
;
std
::
vector
<
int32_t
>
res_lod_levels
;
std
::
vector
<
int32_t
>
res_lod_levels
;
for
(
const
std
::
string
&
var_name
:
sink_var_names
)
{
for
(
const
std
::
string
&
var_name
:
sink_var_names
)
{
auto
*
sink_var
=
auto
*
sink_var
=
sub_block
->
FindVar
(
var_name
);
boost
::
get
<
framework
::
VarDesc
*>
(
ctx
->
GetVarPtr
(
var_name
));
PADDLE_ENFORCE_NOT_NULL
(
sink_var
);
PADDLE_ENFORCE_NOT_NULL
(
sink_var
);
res_dims
.
emplace_back
(
sink_var
->
GetShape
());
res_dims
.
emplace_back
(
sink_var
->
GetShape
());
res_lod_levels
.
push_back
(
sink_var
->
GetLoDLevel
());
res_lod_levels
.
push_back
(
sink_var
->
GetLoDLevel
());
...
@@ -114,9 +114,11 @@ class CustomReaderInferVarType : public framework::VarTypeInference {
...
@@ -114,9 +114,11 @@ class CustomReaderInferVarType : public framework::VarTypeInference {
auto
sink_var_names
=
auto
sink_var_names
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
op_desc
.
GetAttr
(
"sink_var_names"
));
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
op_desc
.
GetAttr
(
"sink_var_names"
));
const
auto
*
sub_block
=
boost
::
get
<
framework
::
BlockDesc
*>
(
op_desc
.
GetAttr
(
"sub_block"
));
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>
res_data_types
;
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>
res_data_types
;
for
(
const
std
::
string
&
var_name
:
sink_var_names
)
{
for
(
const
std
::
string
&
var_name
:
sink_var_names
)
{
framework
::
VarDesc
*
var
=
block
->
FindVar
(
var_name
);
framework
::
VarDesc
*
var
=
sub_
block
->
FindVar
(
var_name
);
PADDLE_ENFORCE_NOT_NULL
(
var
);
PADDLE_ENFORCE_NOT_NULL
(
var
);
res_data_types
.
emplace_back
(
var
->
GetDataType
());
res_data_types
.
emplace_back
(
var
->
GetDataType
());
}
}
...
@@ -152,8 +154,7 @@ void CustomReader::ReadNext(std::vector<framework::LoDTensor>* out) {
...
@@ -152,8 +154,7 @@ void CustomReader::ReadNext(std::vector<framework::LoDTensor>* out) {
framework
::
Executor
executor
(
dev_place_
);
framework
::
Executor
executor
(
dev_place_
);
framework
::
ProgramDesc
*
program
=
sub_block_
.
Program
();
framework
::
ProgramDesc
*
program
=
sub_block_
.
Program
();
framework
::
Scope
*
exe_scope
=
&
scope_
.
NewScope
();
framework
::
Scope
*
exe_scope
=
&
scope_
.
NewScope
();
executor
.
Run
(
*
program
,
exe_scope
,
sub_block_
.
ID
(),
executor
.
Run
(
*
program
,
exe_scope
,
sub_block_
.
ID
(),
false
,
true
);
false
/*create_local_scope*/
,
true
);
scope_
.
DeleteScope
(
exe_scope
);
scope_
.
DeleteScope
(
exe_scope
);
// 3. Copy LoDTensors from sink variables to out.
// 3. Copy LoDTensors from sink variables to out.
out
->
resize
(
sink_var_names_
.
size
());
out
->
resize
(
sink_var_names_
.
size
());
...
...
python/paddle/fluid/layers/io.py
浏览文件 @
b48eba19
...
@@ -11,6 +11,7 @@
...
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
contextlib
from
..
import
core
from
..
import
core
from
..framework
import
convert_np_dtype_to_dtype_
,
default_main_program
,
default_startup_program
,
Program
from
..framework
import
convert_np_dtype_to_dtype_
,
default_main_program
,
default_startup_program
,
Program
...
@@ -21,7 +22,8 @@ from ..executor import global_scope
...
@@ -21,7 +22,8 @@ from ..executor import global_scope
__all__
=
[
__all__
=
[
'data'
,
'BlockGuardServ'
,
'ListenAndServ'
,
'Send'
,
'open_recordio_file'
,
'data'
,
'BlockGuardServ'
,
'ListenAndServ'
,
'Send'
,
'open_recordio_file'
,
'open_files'
,
'read_file'
,
'shuffle'
,
'batch'
,
'double_buffer'
'open_files'
,
'read_file'
,
'shuffle'
,
'batch'
,
'double_buffer'
,
'Preprocessor'
]
]
...
@@ -468,8 +470,6 @@ def __create_unshared_decorated_reader__(op_type, reader, attrs, name=None):
...
@@ -468,8 +470,6 @@ def __create_unshared_decorated_reader__(op_type, reader, attrs, name=None):
inputs
=
{
'UnderlyingReader'
:
reader
},
inputs
=
{
'UnderlyingReader'
:
reader
},
outputs
=
{
'Out'
:
[
new_reader
]},
outputs
=
{
'Out'
:
[
new_reader
]},
attrs
=
attrs
)
attrs
=
attrs
)
new_reader
.
persistable
=
True
new_reader
.
stop_gradient
=
True
return
monkey_patch_reader_methods
(
new_reader
)
return
monkey_patch_reader_methods
(
new_reader
)
...
@@ -514,3 +514,81 @@ def read_file(file_obj):
...
@@ -514,3 +514,81 @@ def read_file(file_obj):
return
out
[
0
]
return
out
[
0
]
else
:
else
:
return
out
return
out
class Preprocessor(object):
    """Wraps a reader with a user-defined sub-block of preprocessing ops.

    The sub-block is defined inside ``with preprocessor.block():``. Within it,
    ``inputs()`` creates the source variables and ``outputs(...)`` records the
    sink variables. Calling the instance afterwards appends a
    ``create_custom_reader`` op and returns the decorated reader.

    Typical usage::

        preprocessor = Preprocessor(reader=data_file)
        with preprocessor.block():
            img, lbl = preprocessor.inputs()
            preprocessor.outputs(img / 2, lbl + 1)
        new_reader = preprocessor()
    """

    # Life-cycle states, tracked in ``self.status``.
    BEFORE_SUB_BLOCK = 0
    IN_SUB_BLOCK = 1
    AFTER_SUB_BLOCK = 2

    def __init__(self, reader, name=None):
        """Create the output reader variable in the current block.

        Args:
            reader: the underlying reader variable to be decorated.
            name: optional name of the new reader variable; a unique name is
                generated when it is None.
        """
        self.underlying_reader = reader
        new_reader_name = name if name is not None else unique_name(
            "create_custom_reader")
        self.main_prog = default_main_program()
        self.reader = self.main_prog.current_block().create_var(
            name=new_reader_name)
        self.sub_block = None
        self.source_var_names = None
        self.sink_var_names = None
        self.status = Preprocessor.BEFORE_SUB_BLOCK

    def is_completed(self):
        """Return a truthy value once the sub-block, the source variable
        names and the sink variable names have all been set."""
        return self.sub_block and self.source_var_names and self.sink_var_names

    @contextlib.contextmanager
    def block(self):
        """Context manager that opens the preprocessing sub-block.

        Raises:
            RuntimeError: on normal exit, if ``inputs`` and ``outputs`` have
                not both been invoked inside the sub-block.
        """
        self.status = Preprocessor.IN_SUB_BLOCK
        self.sub_block = self.main_prog.create_block()
        try:
            yield
        finally:
            # Always restore the program's current block, even when the body
            # of the ``with`` statement raises.
            self.main_prog.rollback()
        self.status = Preprocessor.AFTER_SUB_BLOCK
        if not self.is_completed():
            raise RuntimeError(
                "The definition of preprocessor is incompleted! "
                "Please make sure that you have set input and output "
                "variables by invoking 'inputs' and 'outputs' in "
                "Preprocessor's sub-block.")

    def inputs(self):
        """Create and return the sub-block's source variables.

        One variable is created per entry of the underlying reader's
        ``desc.shapes()``, using the matching dtype and LoD level.

        Raises:
            RuntimeError: if invoked outside the sub-block.
        """
        if self.status != Preprocessor.IN_SUB_BLOCK:
            raise RuntimeError(
                "Preprocessor.inputs() can only be invoked inside the sub-block."
            )

        source_shapes = self.underlying_reader.desc.shapes()
        source_dtypes = self.underlying_reader.desc.dtypes()
        source_lod_levels = self.underlying_reader.desc.lod_levels()
        self.source_var_names = []
        source_vars = []
        # ``range`` instead of ``xrange`` so the loop runs on Python 2 and 3.
        for idx in range(len(source_shapes)):
            self.source_var_names.append(unique_name("preprocessor_source"))
            source_vars.append(self.main_prog.current_block().create_var(
                name=self.source_var_names[-1],
                shape=source_shapes[idx],
                dtype=source_dtypes[idx],
                lod_level=source_lod_levels[idx]))
        return source_vars

    def outputs(self, *outs):
        """Record the names of the given variables as the sink variables.

        Raises:
            RuntimeError: if invoked outside the sub-block.
        """
        if self.status != Preprocessor.IN_SUB_BLOCK:
            raise RuntimeError(
                "Preprocessor.outputs() can only be invoked inside the sub-block."
            )
        self.sink_var_names = [var.name for var in outs]

    def __call__(self, *args, **kwargs):
        """Append the ``create_custom_reader`` op and return the new reader.

        Raises:
            RuntimeError: if invoked before the sub-block has been closed.
        """
        if self.status != Preprocessor.AFTER_SUB_BLOCK:
            raise RuntimeError(
                "Preprocessor output can only be retrieved after the "
                "sub-block is defined.")

        self.main_prog.current_block().append_op(
            type="create_custom_reader",
            inputs={'UnderlyingReader': self.underlying_reader},
            outputs={'Out': [self.reader]},
            attrs={
                "sub_block": self.sub_block,
                "source_var_names": self.source_var_names,
                "sink_var_names": self.sink_var_names
            })
        return monkey_patch_reader_methods(self.reader)
python/paddle/fluid/tests/unittests/test_preprocessor.py
0 → 100644
浏览文件 @
b48eba19
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
paddle.fluid
as
fluid
import
paddle.v2
as
paddle
import
paddle.v2.dataset.mnist
as
mnist
class TestPreprocessor(unittest.TestCase):
    """Checks that a Preprocessor sub-block transforms the data of a
    recordio-file reader."""

    def setUp(self):
        # Convert batches of the MNIST training set (batch size 32) into a
        # recordio file and remember the value the converter returns.
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            reader = paddle.batch(mnist.train(), batch_size=32)
            feeder = fluid.DataFeeder(
                feed_list=[  # order is image and label
                    fluid.layers.data(
                        name='image', shape=[784]),
                    fluid.layers.data(
                        name='label', shape=[1], dtype='int64'),
                ],
                place=fluid.CPUPlace())
            self.num_batches = fluid.recordio_writer.convert_reader_to_recordio_file(
                './mnist_for_preprocessor_test.recordio', reader, feeder)

    def test_main(self):
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            data_file = fluid.layers.io.open_recordio_file(
                './mnist_for_preprocessor_test.recordio',
                shapes=[[-1, 784], [-1, 1]],
                lod_levels=[0, 0],
                dtypes=['float32', 'int64'])
            preprocessor = fluid.layers.io.Preprocessor(reader=data_file)
            with preprocessor.block():
                img, lbl = preprocessor.inputs()
                img_out = img / 2
                lbl_out = lbl + 1
                preprocessor.outputs(img_out, lbl_out)

            # NOTE(review): both read_file calls below draw from `data_file`
            # (one directly, one through the preprocessor). Presumably each
            # advances the same underlying reader, in which case the "before"
            # and "after" tensors of one run are different batches -- confirm
            # and, if so, read them in separate programs.
            img_before, lbl_before = fluid.layers.io.read_file(data_file)
            img_after, lbl_after = fluid.layers.io.read_file(preprocessor())

            if fluid.core.is_compiled_with_cuda():
                place = fluid.CUDAPlace(0)
            else:
                place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            exe.run(fluid.default_startup_program())
            for _ in range(5):
                img_b, lbl_b, img_a, lbl_a = exe.run(
                    fetch_list=[img_before, lbl_before, img_after, lbl_after])
                # assertEqual on multi-element arrays raises "truth value of
                # an array is ambiguous"; compare element-wise and reduce.
                # Assumes exe.run returns numpy arrays -- TODO confirm.
                self.assertTrue((img_b / 2 == img_a).all())
                self.assertTrue((lbl_b + 1 == lbl_a).all())
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录