Commit cd8700f1 (unverified)
Authored on May 25, 2018 by fengjiayi; committed via GitHub on May 25, 2018.

Merge pull request #10872 from JiayiFeng/dev_CustomReader

CustomReader

Parents: 75303664, 81470635
Showing 6 changed files with 365 additions and 5 deletions (+365, -5):
- paddle/fluid/framework/shape_inference.h (+1, -2)
- paddle/fluid/operators/reader/CMakeLists.txt (+1, -0)
- paddle/fluid/operators/reader/create_custom_reader_op.cc (+187, -0)
- paddle/fluid/operators/reader/reader_op_registry.cc (+1, -0)
- python/paddle/fluid/layers/io.py (+82, -3)
- python/paddle/fluid/tests/unittests/test_preprocessor.py (+93, -0)
paddle/fluid/framework/shape_inference.h (+1, -2)

@@ -63,6 +63,7 @@ class InferShapeContext {
   std::vector<InferShapeVarPtr> GetInputVarPtrs(const std::string &name);
   std::vector<InferShapeVarPtr> GetOutputVarPtrs(const std::string &name);
+  virtual InferShapeVarPtr GetVarPtr(const std::string &name) = 0;

   // Note: In while op, we need this to be public
   void SetDims(const std::vector<std::string> &names,
@@ -81,8 +82,6 @@ class InferShapeContext {
       const std::vector<std::string> &names) const;

   virtual proto::VarType::Type GetVarType(const std::string &name) const = 0;
-
-  virtual InferShapeVarPtr GetVarPtr(const std::string &name) = 0;
 };

 }  // namespace framework
paddle/fluid/operators/reader/CMakeLists.txt (+1, -0)

@@ -23,6 +23,7 @@ reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_o
 reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc)
 reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc)
 reader_library(create_threaded_reader_op SRCS create_threaded_reader_op.cc)
+reader_library(create_custom_reader_op SRCS create_custom_reader_op.cc)
 cc_test(reader_blocking_queue_test SRCS reader_blocking_queue_test.cc)

 # Export local libraries to parent
paddle/fluid/operators/reader/create_custom_reader_op.cc (new file, mode 100644, +187 -0)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"

namespace paddle {
namespace operators {
namespace reader {

class CustomReader : public framework::DecoratedReader {
 public:
  CustomReader(ReaderBase* reader, const framework::BlockDesc& sub_block,
               const platform::Place& dev_place,
               const std::vector<std::string>& source_var_names,
               const std::vector<std::string>& sink_var_names)
      : DecoratedReader(reader),
        program_(*sub_block.Program()),
        sub_block_id_(sub_block.ID()),
        exe_(framework::Executor(dev_place)),
        source_var_names_(source_var_names),
        sink_var_names_(sink_var_names) {}

  void ReadNext(std::vector<framework::LoDTensor>* out) override;

 private:
  const framework::ProgramDesc program_;
  int sub_block_id_;
  framework::Executor exe_;

  std::vector<std::string> source_var_names_;
  std::vector<std::string> sink_var_names_;
};

class CreateCustomReaderOp : public framework::OperatorBase {
 public:
  using framework::OperatorBase::OperatorBase;

 private:
  void RunImpl(const framework::Scope& scope,
               const platform::Place& dev_place) const override {
    auto* out = scope.FindVar(Output("Out"))
                    ->template GetMutable<framework::ReaderHolder>();
    auto* sub_block = Attr<framework::BlockDesc*>("sub_block");
    if (out->Get() != nullptr) {
      return;
    }
    const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
                                        ->Get<framework::ReaderHolder>();
    out->Reset(
        new CustomReader(underlying_reader.Get(), *sub_block, dev_place,
                         Attr<std::vector<std::string>>("source_var_names"),
                         Attr<std::vector<std::string>>("sink_var_names")));
  }
};

class CreateCustomReaderOpMaker : public DecoratedReaderMakerBase {
 protected:
  void Apply() override {
    AddAttr<framework::BlockDesc*>(
        "sub_block", "The block to hold all preprocessing operators.");
    AddAttr<std::vector<std::string>>(
        "source_var_names",
        "Source variables are starting points of data preprocessing. They hold "
        "preprocessing's input tensors. Each source variable corresponds to "
        "one of underlying reader's output datas.");
    AddAttr<std::vector<std::string>>(
        "sink_var_names",
        "Sink variables are ending points of data preprocessing. They hold "
        "preprocessing's output tensors. Each sink variable corresponds to "
        "one of custom reader's output datas.");
    AddComment(R"DOC(
      CreateCustomReader Operator

      A custom reader can be used for input data preprocessing.
      A custom reader holds its own sub-block, which will be executed in its
      'ReadNext()' function. Users can configurate their own preprocessing
      pipelines by inserting operators into custom reader's sub-block.
    )DOC");
  }
};

class CustomReaderInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(!ctx->IsRuntime(),
                   "'CustomReaderInferShape' should only be invoked during "
                   "compile time.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "The output decorated reader should not be null.");
    const auto* sub_block =
        ctx->Attrs().Get<framework::BlockDesc*>("sub_block");
    const auto sink_var_names =
        ctx->Attrs().Get<std::vector<std::string>>("sink_var_names");
    std::vector<std::vector<int64_t>> res_dims;
    std::vector<int32_t> res_lod_levels;
    for (const std::string& var_name : sink_var_names) {
      auto* sink_var = sub_block->FindVar(var_name);
      PADDLE_ENFORCE_NOT_NULL(sink_var);
      res_dims.emplace_back(sink_var->GetShape());
      res_lod_levels.push_back(sink_var->GetLoDLevel());
    }
    auto* out_reader =
        boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
    out_reader->SetShapes(res_dims);
    out_reader->SetLoDLevels(res_lod_levels);
  }
};

class CustomReaderInferVarType : public framework::VarTypeInference {
 public:
  void operator()(const framework::OpDesc& op_desc,
                  framework::BlockDesc* block) const override {
    framework::VarDesc* out_reader = block->FindVar(op_desc.Output("Out")[0]);
    PADDLE_ENFORCE_NOT_NULL(out_reader);
    out_reader->SetType(framework::proto::VarType::READER);
    auto sink_var_names =
        boost::get<std::vector<std::string>>(op_desc.GetAttr("sink_var_names"));
    const auto* sub_block =
        boost::get<framework::BlockDesc*>(op_desc.GetAttr("sub_block"));
    std::vector<framework::proto::VarType::Type> res_data_types;
    for (const std::string& var_name : sink_var_names) {
      framework::VarDesc* var = sub_block->FindVar(var_name);
      PADDLE_ENFORCE_NOT_NULL(var);
      res_data_types.emplace_back(var->GetDataType());
    }
    out_reader->SetDataTypes(res_data_types);
  }
};

void CustomReader::ReadNext(std::vector<framework::LoDTensor>* out) {
  out->clear();
  std::vector<framework::LoDTensor> underlying_outs;
  reader_->ReadNext(&underlying_outs);
  if (underlying_outs.empty()) {
    // There is not next data.
    return;
  }
  PADDLE_ENFORCE(
      source_var_names_.size() == underlying_outs.size(),
      "The size of source_var_names(%d) and the size of "
      "underlying_outs(%d) are not consistent. Each feeding element "
      "must have its own source variable.",
      source_var_names_.size(), underlying_outs.size());
  // The scope for CustomReader's sub-block should be independent and shouldn't
  // be any other computation scope's child. Otherwise, data preprocessing and
  // compution cannot be concurrent.
  framework::Scope scope;
  // 1. Copy LoDTensors from underlying reader's output to source variables.
  for (size_t i = 0; i < source_var_names_.size(); ++i) {
    framework::Variable* var = scope.Var(source_var_names_[i]);
    framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>();
    tensor->ShareDataWith(underlying_outs[i]);
    tensor->set_lod(underlying_outs[i].lod());
  }
  // 2. Run the sub-block.
  exe_.Run(program_, &scope, sub_block_id_, false, true);
  // 3. Copy LoDTensors from sink variables to out.
  out->resize(sink_var_names_.size());
  for (size_t i = 0; i < sink_var_names_.size(); ++i) {
    const auto& tensor = detail::Ref(scope.FindVar(sink_var_names_[i]))
                             .Get<framework::LoDTensor>();
    framework::TensorCopySync(tensor, platform::CPUPlace(), &(*out)[i]);
  }
}

}  // namespace reader
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators::reader;
REGISTER_OPERATOR(create_custom_reader, ops::CreateCustomReaderOp,
                  ops::CreateCustomReaderOpMaker, ops::CustomReaderInferShape,
                  ops::CustomReaderInferVarType,
                  paddle::framework::EmptyGradOpMaker)
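
The ReadNext() implementation above follows a three-step pattern: copy the underlying reader's output tensors into the source variables of a fresh, independent scope, run the preprocessing sub-block with the Executor, then copy the sink variables into the output vector. The pure-Python toy below is only an illustrative analogue of that control flow (it is not Paddle code, and every name in it is invented for the example); it may help when reading the C++ above.

# Toy analogue of CustomReader::ReadNext(): wrap an underlying "reader"
# (any iterator of tuples) and run a user-defined "sub-block" (here a plain
# function) between named source and sink variables.
class ToyCustomReader(object):
    def __init__(self, underlying, source_names, sink_names, sub_block_fn):
        self.underlying = underlying        # iterator yielding tuples
        self.source_names = source_names    # names feeding the sub-block
        self.sink_names = sink_names        # names produced by the sub-block
        self.sub_block_fn = sub_block_fn    # dict -> dict, the "sub-block"

    def read_next(self):
        try:
            outs = next(self.underlying)
        except StopIteration:
            return None                     # there is no next data
        assert len(outs) == len(self.source_names)
        # 1. Copy the underlying outputs into source variables of a fresh scope.
        scope = dict(zip(self.source_names, outs))
        # 2. "Run" the preprocessing sub-block on that scope.
        scope.update(self.sub_block_fn(scope))
        # 3. Copy the sink variables out, in order.
        return tuple(scope[name] for name in self.sink_names)

# Example: halve the image value and shift the label, as the unit test below does.
reader = ToyCustomReader(
    underlying=iter([(4.0, 1), (8.0, 2)]),
    source_names=["img", "lbl"],
    sink_names=["img_out", "lbl_out"],
    sub_block_fn=lambda s: {"img_out": s["img"] / 2, "lbl_out": s["lbl"] + 1})
print(reader.read_next())  # (2.0, 2)
print(reader.read_next())  # (4.0, 3)
print(reader.read_next())  # None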
paddle/fluid/operators/reader/reader_op_registry.cc (+1, -0)

@@ -115,6 +115,7 @@ void DecoratedReaderInferShape::operator()(
       boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
   out_reader->SetLoDLevels(in_reader->GetLoDLevels());
 }

 void DecoratedReaderInferVarType::operator()(
     const framework::OpDesc& op_desc, framework::BlockDesc* block) const {
   std::string in_reader_name = op_desc.Input("UnderlyingReader")[0];
python/paddle/fluid/layers/io.py (+82, -3)

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import contextlib

 from .. import core
 from ..framework import convert_np_dtype_to_dtype_, default_main_program, default_startup_program, Program
@@ -21,7 +22,8 @@ from ..executor import global_scope

 __all__ = [
     'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
-    'open_files', 'read_file', 'shuffle', 'batch', 'double_buffer'
+    'open_files', 'read_file', 'shuffle', 'batch', 'double_buffer',
+    'Preprocessor'
 ]
@@ -535,8 +537,6 @@ def __create_unshared_decorated_reader__(op_type, reader, attrs, name=None):
         inputs={'UnderlyingReader': reader},
         outputs={'Out': [new_reader]},
         attrs=attrs)
-    new_reader.persistable = True
-    new_reader.stop_gradient = True
     return monkey_patch_reader_methods(new_reader)
@@ -581,3 +581,82 @@ def read_file(file_obj):
         return out[0]
     else:
         return out
+
+
+class Preprocessor(object):
+    BEFORE_SUB_BLOCK = 0
+    IN_SUB_BLOCK = 1
+    AFTER_SUB_BLOCK = 2
+
+    def __init__(self, reader, name=None):
+        self.underlying_reader = reader
+        new_reader_name = name if name is not None else unique_name(
+            "create_custom_reader")
+        self.main_prog = default_main_program()
+        self.reader = self.main_prog.current_block().create_var(
+            name=new_reader_name)
+        self.sub_block = None
+        self.source_var_names = None
+        self.sink_var_names = None
+        self.status = Preprocessor.BEFORE_SUB_BLOCK
+
+    def is_completed(self):
+        return self.sub_block and self.source_var_names and self.sink_var_names
+
+    @contextlib.contextmanager
+    def block(self):
+        self.status = Preprocessor.IN_SUB_BLOCK
+        self.sub_block = self.main_prog.create_block()
+        yield
+        self.main_prog.rollback()
+        self.status = Preprocessor.AFTER_SUB_BLOCK
+        if not self.is_completed():
+            raise RuntimeError(
+                "The definition of preprocessor is incompleted! "
+                "Please make sure that you have set input and output "
+                "variables by invoking 'inputs' and 'outputs' in "
+                "Preprocessor's sub-block.")
+
+    def inputs(self):
+        if self.status != Preprocessor.IN_SUB_BLOCK:
+            raise RuntimeError(
+                "Preprocessor.inputs() can only be invoked inside the sub-block."
+            )
+
+        source_shapes = self.underlying_reader.desc.shapes()
+        source_dtypes = self.underlying_reader.desc.dtypes()
+        source_lod_levels = self.underlying_reader.desc.lod_levels()
+        self.source_var_names = [
+            unique_name("preprocessor_source")
+            for _ in xrange(len(source_shapes))
+        ]
+        source_vars = []
+        for var_name, shape, dtype, lod_level in zip(
+                self.source_var_names, source_shapes, source_dtypes,
+                source_lod_levels):
+            source_vars.append(self.main_prog.current_block().create_var(
+                name=var_name, shape=shape, dtype=dtype, lod_level=lod_level))
+        return source_vars
+
+    def outputs(self, *outs):
+        if self.status != Preprocessor.IN_SUB_BLOCK:
+            raise RuntimeError(
+                "Preprocessor.outputs() can only be invoked inside the sub-block."
+            )
+        self.sink_var_names = [var.name for var in outs]
+
+    def __call__(self, *args, **kwargs):
+        if self.status != Preprocessor.AFTER_SUB_BLOCK:
+            raise RuntimeError(
+                "Preprocessor output can only be retrieved after rnn block.")
+
+        self.main_prog.current_block().append_op(
+            type="create_custom_reader",
+            inputs={'UnderlyingReader': self.underlying_reader},
+            outputs={'Out': [self.reader]},
+            attrs={
+                "sub_block": self.sub_block,
+                "source_var_names": self.source_var_names,
+                "sink_var_names": self.sink_var_names
+            })
+        return monkey_patch_reader_methods(self.reader)
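
Before the unit test that follows, here is a condensed usage sketch of the new Preprocessor layer: open a reader, define the preprocessing inside preprocessor.block() through inputs() and outputs(), then read from the decorated reader. It is a sketch under the fluid API of this commit and assumes a RecordIO file already exists at the hypothetical path './data.recordio'; it is not additional code from this change.

import paddle.fluid as fluid

# Assumed to exist: './data.recordio' with (image, label) samples, e.g. written
# by fluid.recordio_writer.convert_reader_to_recordio_file as in the test below.
data_file = fluid.layers.io.open_recordio_file(
    './data.recordio',
    shapes=[[-1, 784], [-1, 1]],
    lod_levels=[0, 0],
    dtypes=['float32', 'int64'])

preprocessor = fluid.layers.io.Preprocessor(reader=data_file)
with preprocessor.block():
    img, lbl = preprocessor.inputs()        # source variables
    preprocessor.outputs(img / 2, lbl + 1)  # sink variables

data_file = fluid.layers.io.double_buffer(preprocessor())
img, lbl = fluid.layers.io.read_file(data_file)

place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda() else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
img_v, lbl_v = exe.run(fetch_list=[img, lbl])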
python/paddle/fluid/tests/unittests/test_preprocessor.py (new file, mode 100644, +93 -0)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np

import paddle.fluid as fluid
import paddle.v2 as paddle
import paddle.v2.dataset.mnist as mnist


class TestPreprocessor(unittest.TestCase):
    def setUp(self):
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            reader = paddle.batch(mnist.train(), batch_size=32)
            feeder = fluid.DataFeeder(
                feed_list=[  # order is image and label
                    fluid.layers.data(
                        name='image', shape=[784]),
                    fluid.layers.data(
                        name='label', shape=[1], dtype='int64'),
                ],
                place=fluid.CPUPlace())
            self.num_batches = fluid.recordio_writer.convert_reader_to_recordio_file(
                './mnist_for_preprocessor_test.recordio', reader, feeder)

    def test_main(self):
        N = 10

        img_expected_res = []
        lbl_expected_res = []
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            data_file = fluid.layers.io.open_recordio_file(
                './mnist_for_preprocessor_test.recordio',
                shapes=[[-1, 784], [-1, 1]],
                lod_levels=[0, 0],
                dtypes=['float32', 'int64'])
            img, lbl = fluid.layers.io.read_file(data_file)

            if fluid.core.is_compiled_with_cuda():
                place = fluid.CUDAPlace(0)
            else:
                place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            exe.run(fluid.default_startup_program())
            for _ in range(N):
                img_v, lbl_v = exe.run(fetch_list=[img, lbl])
                img_expected_res.append(img_v / 2)
                lbl_expected_res.append(lbl_v + 1)

        img_actual_res = []
        lbl_actual_res = []
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            data_file = fluid.layers.io.open_recordio_file(
                './mnist_for_preprocessor_test.recordio',
                shapes=[[-1, 784], [-1, 1]],
                lod_levels=[0, 0],
                dtypes=['float32', 'int64'])
            preprocessor = fluid.layers.io.Preprocessor(reader=data_file)
            with preprocessor.block():
                img, lbl = preprocessor.inputs()
                img_out = img / 2
                lbl_out = lbl + 1
                preprocessor.outputs(img_out, lbl_out)

            data_file = fluid.layers.io.double_buffer(preprocessor())
            img, lbl = fluid.layers.io.read_file(data_file)

            if fluid.core.is_compiled_with_cuda():
                place = fluid.CUDAPlace(0)
            else:
                place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            exe.run(fluid.default_startup_program())
            for _ in range(N):
                img_v, lbl_v = exe.run(fetch_list=[img, lbl])
                img_actual_res.append(img_v)
                lbl_actual_res.append(lbl_v)

        for idx in range(N):
            np.allclose(img_expected_res[idx], img_actual_res[idx])
            np.allclose(lbl_expected_res[idx], lbl_actual_res[idx])