提交 b48eba19 编写于 作者: F fengjiayi

complete python API and unit test

上级 983c9a2a
...@@ -65,9 +65,8 @@ class CreateCustomReaderOp : public framework::OperatorBase { ...@@ -65,9 +65,8 @@ class CreateCustomReaderOp : public framework::OperatorBase {
}; };
class CreateCustomReaderOpMaker : public DecoratedReaderMakerBase { class CreateCustomReaderOpMaker : public DecoratedReaderMakerBase {
public: protected:
CreateCustomReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) void Apply() override {
: DecoratedReaderMakerBase(op_proto, op_checker) {
AddAttr<framework::BlockDesc*>("sub_block", ""); AddAttr<framework::BlockDesc*>("sub_block", "");
AddAttr<std::vector<std::string>>("source_var_names", ""); AddAttr<std::vector<std::string>>("source_var_names", "");
AddAttr<std::vector<std::string>>("sink_var_names", ""); AddAttr<std::vector<std::string>>("sink_var_names", "");
...@@ -86,13 +85,14 @@ class CustomReaderInferShape : public framework::InferShapeBase { ...@@ -86,13 +85,14 @@ class CustomReaderInferShape : public framework::InferShapeBase {
"compile time."); "compile time.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"The output decorated reader should not be null."); "The output decorated reader should not be null.");
const auto* sub_block =
ctx->Attrs().Get<framework::BlockDesc*>("sub_block");
const auto sink_var_names = const auto sink_var_names =
ctx->Attrs().Get<std::vector<std::string>>("sink_var_names"); ctx->Attrs().Get<std::vector<std::string>>("sink_var_names");
std::vector<std::vector<int64_t>> res_dims; std::vector<std::vector<int64_t>> res_dims;
std::vector<int32_t> res_lod_levels; std::vector<int32_t> res_lod_levels;
for (const std::string& var_name : sink_var_names) { for (const std::string& var_name : sink_var_names) {
auto* sink_var = auto* sink_var = sub_block->FindVar(var_name);
boost::get<framework::VarDesc*>(ctx->GetVarPtr(var_name));
PADDLE_ENFORCE_NOT_NULL(sink_var); PADDLE_ENFORCE_NOT_NULL(sink_var);
res_dims.emplace_back(sink_var->GetShape()); res_dims.emplace_back(sink_var->GetShape());
res_lod_levels.push_back(sink_var->GetLoDLevel()); res_lod_levels.push_back(sink_var->GetLoDLevel());
...@@ -114,9 +114,11 @@ class CustomReaderInferVarType : public framework::VarTypeInference { ...@@ -114,9 +114,11 @@ class CustomReaderInferVarType : public framework::VarTypeInference {
auto sink_var_names = auto sink_var_names =
boost::get<std::vector<std::string>>(op_desc.GetAttr("sink_var_names")); boost::get<std::vector<std::string>>(op_desc.GetAttr("sink_var_names"));
const auto* sub_block =
boost::get<framework::BlockDesc*>(op_desc.GetAttr("sub_block"));
std::vector<framework::proto::VarType::Type> res_data_types; std::vector<framework::proto::VarType::Type> res_data_types;
for (const std::string& var_name : sink_var_names) { for (const std::string& var_name : sink_var_names) {
framework::VarDesc* var = block->FindVar(var_name); framework::VarDesc* var = sub_block->FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(var); PADDLE_ENFORCE_NOT_NULL(var);
res_data_types.emplace_back(var->GetDataType()); res_data_types.emplace_back(var->GetDataType());
} }
...@@ -152,8 +154,7 @@ void CustomReader::ReadNext(std::vector<framework::LoDTensor>* out) { ...@@ -152,8 +154,7 @@ void CustomReader::ReadNext(std::vector<framework::LoDTensor>* out) {
framework::Executor executor(dev_place_); framework::Executor executor(dev_place_);
framework::ProgramDesc* program = sub_block_.Program(); framework::ProgramDesc* program = sub_block_.Program();
framework::Scope* exe_scope = &scope_.NewScope(); framework::Scope* exe_scope = &scope_.NewScope();
executor.Run(*program, exe_scope, sub_block_.ID(), executor.Run(*program, exe_scope, sub_block_.ID(), false, true);
false /*create_local_scope*/, true);
scope_.DeleteScope(exe_scope); scope_.DeleteScope(exe_scope);
// 3. Copy LoDTensors from sink variables to out. // 3. Copy LoDTensors from sink variables to out.
out->resize(sink_var_names_.size()); out->resize(sink_var_names_.size());
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import contextlib
from .. import core from .. import core
from ..framework import convert_np_dtype_to_dtype_, default_main_program, default_startup_program, Program from ..framework import convert_np_dtype_to_dtype_, default_main_program, default_startup_program, Program
...@@ -21,7 +22,8 @@ from ..executor import global_scope ...@@ -21,7 +22,8 @@ from ..executor import global_scope
__all__ = [ __all__ = [
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file', 'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
'open_files', 'read_file', 'shuffle', 'batch', 'double_buffer' 'open_files', 'read_file', 'shuffle', 'batch', 'double_buffer',
'Preprocessor'
] ]
...@@ -468,8 +470,6 @@ def __create_unshared_decorated_reader__(op_type, reader, attrs, name=None): ...@@ -468,8 +470,6 @@ def __create_unshared_decorated_reader__(op_type, reader, attrs, name=None):
inputs={'UnderlyingReader': reader}, inputs={'UnderlyingReader': reader},
outputs={'Out': [new_reader]}, outputs={'Out': [new_reader]},
attrs=attrs) attrs=attrs)
new_reader.persistable = True
new_reader.stop_gradient = True
return monkey_patch_reader_methods(new_reader) return monkey_patch_reader_methods(new_reader)
...@@ -514,3 +514,81 @@ def read_file(file_obj): ...@@ -514,3 +514,81 @@ def read_file(file_obj):
return out[0] return out[0]
else: else:
return out return out
class Preprocessor(object):
BEFORE_SUB_BLOCK = 0
IN_SUB_BLOCK = 1
AFTER_SUB_BLOCK = 2
def __init__(self, reader, name=None):
self.underlying_reader = reader
new_reader_name = name if name is not None else unique_name(
"create_custom_reader")
self.main_prog = default_main_program()
self.reader = self.main_prog.current_block().create_var(
name=new_reader_name)
self.sub_block = None
self.source_var_names = None
self.sink_var_names = None
self.status = Preprocessor.BEFORE_SUB_BLOCK
def is_completed(self):
return self.sub_block and self.source_var_names and self.sink_var_names
@contextlib.contextmanager
def block(self):
self.status = Preprocessor.IN_SUB_BLOCK
self.sub_block = self.main_prog.create_block()
yield
self.main_prog.rollback()
self.status = Preprocessor.AFTER_SUB_BLOCK
if not self.is_completed():
raise RuntimeError(
"The definition of preprocessor is incompleted! "
"Please make sure that you have set input and output "
"variables by invoking 'inputs' and 'outputs' in "
"Preprocessor's sub-block.")
def inputs(self):
if self.status != Preprocessor.IN_SUB_BLOCK:
raise RuntimeError(
"Preprocessor.inputs() can only be invoked inside the sub-block."
)
source_shapes = self.underlying_reader.desc.shapes()
source_dtypes = self.underlying_reader.desc.dtypes()
source_lod_levels = self.underlying_reader.desc.lod_levels()
self.source_var_names = []
source_vars = []
for idx in xrange(len(source_shapes)):
self.source_var_names.append(unique_name("preprocessor_source"))
source_vars.append(self.main_prog.current_block().create_var(
name=self.source_var_names[-1],
shape=source_shapes[idx],
dtype=source_dtypes[idx],
lod_level=source_lod_levels[idx]))
return source_vars
def outputs(self, *outs):
if self.status != Preprocessor.IN_SUB_BLOCK:
raise RuntimeError(
"Preprocessor.outputs() can only be invoked inside the sub-block."
)
self.sink_var_names = [var.name for var in outs]
def __call__(self, *args, **kwargs):
if self.status != Preprocessor.AFTER_SUB_BLOCK:
raise RuntimeError(
"Preprocessor output can only be retrieved after rnn block.")
self.main_prog.current_block().append_op(
type="create_custom_reader",
inputs={'UnderlyingReader': self.underlying_reader},
outputs={'Out': [self.reader]},
attrs={
"sub_block": self.sub_block,
"source_var_names": self.source_var_names,
"sink_var_names": self.sink_var_names
})
return monkey_patch_reader_methods(self.reader)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.fluid as fluid
import paddle.v2 as paddle
import paddle.v2.dataset.mnist as mnist
class TestPreprocessor(unittest.TestCase):
def setUp(self):
with fluid.program_guard(fluid.Program(), fluid.Program()):
reader = paddle.batch(mnist.train(), batch_size=32)
feeder = fluid.DataFeeder(
feed_list=[ # order is image and label
fluid.layers.data(
name='image', shape=[784]),
fluid.layers.data(
name='label', shape=[1], dtype='int64'),
],
place=fluid.CPUPlace())
self.num_batches = fluid.recordio_writer.convert_reader_to_recordio_file(
'./mnist_for_preprocessor_test.recordio', reader, feeder)
def test_main(self):
with fluid.program_guard(fluid.Program(), fluid.Program()):
data_file = fluid.layers.io.open_recordio_file(
'./mnist_for_preprocessor_test.recordio',
shapes=[[-1, 784], [-1, 1]],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
preprocessor = fluid.layers.io.Preprocessor(reader=data_file)
with preprocessor.block():
img, lbl = preprocessor.inputs()
img_out = img / 2
lbl_out = lbl + 1
preprocessor.outputs(img_out, lbl_out)
img_before, lbl_before = fluid.layers.io.read_file(data_file)
img_after, lbl_after = fluid.layers.io.read_file(preprocessor())
if fluid.core.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
else:
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
for _ in range(5):
img_b, lbl_b, img_a, lbl_a = exe.run(
fetch_list=[img_before, lbl_before, img_after, lbl_after])
self.assertEqual(img_b / 2, img_a)
self.assertEqual(lbl_b + 1, lbl_a)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册