diff --git a/paddle/fluid/framework/shape_inference.h b/paddle/fluid/framework/shape_inference.h index 46c8feec001584a872f7f62682080e0e72c06f50..5f497cafa0f75f7c23d550ef767d55274de7c900 100644 --- a/paddle/fluid/framework/shape_inference.h +++ b/paddle/fluid/framework/shape_inference.h @@ -63,6 +63,7 @@ class InferShapeContext { std::vector GetInputVarPtrs(const std::string &name); std::vector GetOutputVarPtrs(const std::string &name); + virtual InferShapeVarPtr GetVarPtr(const std::string &name) = 0; // Note: In while op, we need this to be public void SetDims(const std::vector &names, @@ -81,8 +82,6 @@ class InferShapeContext { const std::vector &names) const; virtual proto::VarType::Type GetVarType(const std::string &name) const = 0; - - virtual InferShapeVarPtr GetVarPtr(const std::string &name) = 0; }; } // namespace framework diff --git a/paddle/fluid/operators/reader/CMakeLists.txt b/paddle/fluid/operators/reader/CMakeLists.txt index 3106978eb0149b14849dfd1aaad8bbe76791f2f6..62532036f86bfb82465ccd9e0ec526299489932a 100644 --- a/paddle/fluid/operators/reader/CMakeLists.txt +++ b/paddle/fluid/operators/reader/CMakeLists.txt @@ -23,6 +23,7 @@ reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_o reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc) reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc) reader_library(create_threaded_reader_op SRCS create_threaded_reader_op.cc) +reader_library(create_custom_reader_op SRCS create_custom_reader_op.cc) cc_test(reader_blocking_queue_test SRCS reader_blocking_queue_test.cc) # Export local libraries to parent diff --git a/paddle/fluid/operators/reader/create_custom_reader_op.cc b/paddle/fluid/operators/reader/create_custom_reader_op.cc index 6f81075dd7df5dadb16862aec67b4a9236a4e300..e35775ed18b8e304c8f861f4fd28df27d813e8a5 100644 --- a/paddle/fluid/operators/reader/create_custom_reader_op.cc +++ b/paddle/fluid/operators/reader/create_custom_reader_op.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/operators/reader/reader_op_registry.h" namespace paddle { @@ -77,29 +78,101 @@ class CreateCustomReaderOpMaker : public DecoratedReaderMakerBase { } }; -void CustomReader::ReadNext(std::vector* out) { - PADDLE_ENFORCE_EQ( - source_var_names_.size(), out->size(), - "The size of source_var_names(%d) not equals to the size of 'out'(%d). " - "Each element of 'out' must have its own source var in the CustomReader.", - source_var_names_.size(), out->size()); - PADDLE_ENFORCE_EQ( - sink_var_names_.size(), out->size(), - "The size of sink_var_names(%d) not equals to the size of 'out'(%d). " - "Each element of 'out' must have its own sink var in the CustomReader.", - sink_var_names_.size(), out->size()); +class CustomReaderInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(!ctx->IsRuntime(), + "'CustomReaderInferShape' should only be invoked during " + "compile time."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "The output decorated reader should not be null."); + const auto sink_var_names = + ctx->Attrs().Get>("sink_var_names"); + std::vector> res_dims; + std::vector res_lod_levels; + for (const std::string& var_name : sink_var_names) { + auto* sink_var = + boost::get(ctx->GetVarPtr(var_name)); + PADDLE_ENFORCE_NOT_NULL(sink_var); + res_dims.emplace_back(sink_var->GetShape()); + res_lod_levels.push_back(sink_var->GetLoDLevel()); + } + auto* out_reader = + boost::get(ctx->GetOutputVarPtrs("Out")[0]); + out_reader->SetShapes(res_dims); + out_reader->SetLoDLevels(res_lod_levels); + } +}; + +class CustomReaderInferVarType : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc& op_desc, + framework::BlockDesc* block) const override { + framework::VarDesc* out_reader = block->FindVar(op_desc.Output("Out")[0]); + PADDLE_ENFORCE_NOT_NULL(out_reader); + out_reader->SetType(framework::proto::VarType::READER); + auto sink_var_names = + boost::get>(op_desc.GetAttr("sink_var_names")); + std::vector res_data_types; + for (const std::string& var_name : sink_var_names) { + framework::VarDesc* var = block->FindVar(var_name); + PADDLE_ENFORCE_NOT_NULL(var); + res_data_types.emplace_back(var->GetDataType()); + } + out_reader->SetDataTypes(res_data_types); + } +}; + +void CustomReader::ReadNext(std::vector* out) { + out->clear(); + std::vector underlying_outs; + reader_->ReadNext(&underlying_outs); + if (underlying_outs.empty()) { + // There is not next data. + return; + } + PADDLE_ENFORCE( + source_var_names_.size() == underlying_outs.size() && + sink_var_names_.size() == underlying_outs.size(), + "The size of source_var_names(%d), the size of sink_var_names(%d) and " + "the size of underlying_outs(%d) are not consistent. Each feeding " + "element must have its own source and sink variable.", + source_var_names_.size(), sink_var_names_.size(), underlying_outs.size()); + // 1. Copy LoDTensors from underlying reader's output to source variables. for (size_t i = 0; i < source_var_names_.size(); ++i) { - const std::string& var_name = source_var_names_[i]; - framework::Variable* var = scope_.FindVar(var_name); + framework::Variable* var = scope_.FindVar(source_var_names_[i]); PADDLE_ENFORCE_NOT_NULL( var, "CustomReader's source variable '%s' doesn't exist."); - framework::LoDTensor* tensor = var->GetMutable(); + framework::LoDTensor* tensor = var->GetMutable(); + tensor->ShareDataWith(underlying_outs[i]); + tensor->set_lod(underlying_outs[i].lod()); } - // TODO(fengjiayi): 将vector中的数据拷贝到sorce_var和sink_var中 + // 2. Run the sub-block. framework::Executor executor(dev_place_); + framework::ProgramDesc* program = sub_block_.Program(); + framework::Scope* exe_scope = &scope_.NewScope(); + executor.Run(*program, exe_scope, sub_block_.ID(), + false /*create_local_scope*/, true); + scope_.DeleteScope(exe_scope); + // 3. Copy LoDTensors from sink variables to out. + out->resize(sink_var_names_.size()); + for (size_t i = 0; i < sink_var_names_.size(); ++i) { + framework::Variable* var = scope_.FindVar(sink_var_names_[i]); + PADDLE_ENFORCE_NOT_NULL(var, + "CustomReader's sink variable '%s' doesn't exist."); + const framework::LoDTensor& tensor = var->Get(); + (*out)[i].ShareDataWith(tensor); + (*out)[i].set_lod(tensor.lod()); + } } } // namespace reader } // namespace operators } // namespace paddle + +namespace ops = paddle::operators::reader; +REGISTER_OPERATOR(create_custom_reader, ops::CreateCustomReaderOp, + ops::CreateCustomReaderOpMaker, ops::CustomReaderInferShape, + ops::CustomReaderInferVarType, + paddle::framework::EmptyGradOpMaker) diff --git a/paddle/fluid/operators/reader/reader_op_registry.cc b/paddle/fluid/operators/reader/reader_op_registry.cc index 3ff4536819b128d9c593b97f4942a0292a3b6b36..52adc54dc22f60280348060ee535b937a3f58263 100644 --- a/paddle/fluid/operators/reader/reader_op_registry.cc +++ b/paddle/fluid/operators/reader/reader_op_registry.cc @@ -117,6 +117,7 @@ void DecoratedReaderInferShape::operator()( boost::get(ctx->GetOutputVarPtrs("Out")[0]); out_reader->SetLoDLevels(in_reader->GetLoDLevels()); } + void DecoratedReaderInferVarType::operator()( const framework::OpDesc& op_desc, framework::BlockDesc* block) const { std::string in_reader_name = op_desc.Input("UnderlyingReader")[0];