提交 e15d616e 编写于 作者: F fengjiayi

Complete the C++ core of 'CustomReader'

上级 e61a38da
...@@ -63,6 +63,7 @@ class InferShapeContext { ...@@ -63,6 +63,7 @@ class InferShapeContext {
std::vector<InferShapeVarPtr> GetInputVarPtrs(const std::string &name); std::vector<InferShapeVarPtr> GetInputVarPtrs(const std::string &name);
std::vector<InferShapeVarPtr> GetOutputVarPtrs(const std::string &name); std::vector<InferShapeVarPtr> GetOutputVarPtrs(const std::string &name);
virtual InferShapeVarPtr GetVarPtr(const std::string &name) = 0;
// Note: In while op, we need this to be public // Note: In while op, we need this to be public
void SetDims(const std::vector<std::string> &names, void SetDims(const std::vector<std::string> &names,
...@@ -81,8 +82,6 @@ class InferShapeContext { ...@@ -81,8 +82,6 @@ class InferShapeContext {
const std::vector<std::string> &names) const; const std::vector<std::string> &names) const;
virtual proto::VarType::Type GetVarType(const std::string &name) const = 0; virtual proto::VarType::Type GetVarType(const std::string &name) const = 0;
virtual InferShapeVarPtr GetVarPtr(const std::string &name) = 0;
}; };
} // namespace framework } // namespace framework
......
...@@ -23,6 +23,7 @@ reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_o ...@@ -23,6 +23,7 @@ reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_o
reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc) reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc)
reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc) reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc)
reader_library(create_threaded_reader_op SRCS create_threaded_reader_op.cc) reader_library(create_threaded_reader_op SRCS create_threaded_reader_op.cc)
reader_library(create_custom_reader_op SRCS create_custom_reader_op.cc)
cc_test(reader_blocking_queue_test SRCS reader_blocking_queue_test.cc) cc_test(reader_blocking_queue_test SRCS reader_blocking_queue_test.cc)
# Export local libraries to parent # Export local libraries to parent
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h" #include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace paddle { namespace paddle {
...@@ -77,29 +78,101 @@ class CreateCustomReaderOpMaker : public DecoratedReaderMakerBase { ...@@ -77,29 +78,101 @@ class CreateCustomReaderOpMaker : public DecoratedReaderMakerBase {
} }
}; };
void CustomReader::ReadNext(std::vector<framework::LoDTensor>* out) { class CustomReaderInferShape : public framework::InferShapeBase {
PADDLE_ENFORCE_EQ( public:
source_var_names_.size(), out->size(), void operator()(framework::InferShapeContext* ctx) const override {
"The size of source_var_names(%d) not equals to the size of 'out'(%d). " PADDLE_ENFORCE(!ctx->IsRuntime(),
"Each element of 'out' must have its own source var in the CustomReader.", "'CustomReaderInferShape' should only be invoked during "
source_var_names_.size(), out->size()); "compile time.");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE(ctx->HasOutput("Out"),
sink_var_names_.size(), out->size(), "The output decorated reader should not be null.");
"The size of sink_var_names(%d) not equals to the size of 'out'(%d). " const auto sink_var_names =
"Each element of 'out' must have its own sink var in the CustomReader.", ctx->Attrs().Get<std::vector<std::string>>("sink_var_names");
sink_var_names_.size(), out->size()); std::vector<std::vector<int64_t>> res_dims;
std::vector<int32_t> res_lod_levels;
for (const std::string& var_name : sink_var_names) {
auto* sink_var =
boost::get<framework::VarDesc*>(ctx->GetVarPtr(var_name));
PADDLE_ENFORCE_NOT_NULL(sink_var);
res_dims.emplace_back(sink_var->GetShape());
res_lod_levels.push_back(sink_var->GetLoDLevel());
}
auto* out_reader =
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
out_reader->SetShapes(res_dims);
out_reader->SetLoDLevels(res_lod_levels);
}
};
class CustomReaderInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
framework::VarDesc* out_reader = block->FindVar(op_desc.Output("Out")[0]);
PADDLE_ENFORCE_NOT_NULL(out_reader);
out_reader->SetType(framework::proto::VarType::READER);
auto sink_var_names =
boost::get<std::vector<std::string>>(op_desc.GetAttr("sink_var_names"));
std::vector<framework::proto::VarType::Type> res_data_types;
for (const std::string& var_name : sink_var_names) {
framework::VarDesc* var = block->FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(var);
res_data_types.emplace_back(var->GetDataType());
}
out_reader->SetDataTypes(res_data_types);
}
};
void CustomReader::ReadNext(std::vector<framework::LoDTensor>* out) {
out->clear();
std::vector<framework::LoDTensor> underlying_outs;
reader_->ReadNext(&underlying_outs);
if (underlying_outs.empty()) {
// There is not next data.
return;
}
PADDLE_ENFORCE(
source_var_names_.size() == underlying_outs.size() &&
sink_var_names_.size() == underlying_outs.size(),
"The size of source_var_names(%d), the size of sink_var_names(%d) and "
"the size of underlying_outs(%d) are not consistent. Each feeding "
"element must have its own source and sink variable.",
source_var_names_.size(), sink_var_names_.size(), underlying_outs.size());
// 1. Copy LoDTensors from underlying reader's output to source variables.
for (size_t i = 0; i < source_var_names_.size(); ++i) { for (size_t i = 0; i < source_var_names_.size(); ++i) {
const std::string& var_name = source_var_names_[i]; framework::Variable* var = scope_.FindVar(source_var_names_[i]);
framework::Variable* var = scope_.FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
var, "CustomReader's source variable '%s' doesn't exist."); var, "CustomReader's source variable '%s' doesn't exist.");
framework::LoDTensor* tensor = var->GetMutable<framework::loDtensor>(); framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>();
tensor->ShareDataWith(underlying_outs[i]);
tensor->set_lod(underlying_outs[i].lod());
} }
// TODO(fengjiayi): 将vector中的数据拷贝到sorce_var和sink_var中 // 2. Run the sub-block.
framework::Executor executor(dev_place_); framework::Executor executor(dev_place_);
framework::ProgramDesc* program = sub_block_.Program();
framework::Scope* exe_scope = &scope_.NewScope();
executor.Run(*program, exe_scope, sub_block_.ID(),
false /*create_local_scope*/, true);
scope_.DeleteScope(exe_scope);
// 3. Copy LoDTensors from sink variables to out.
out->resize(sink_var_names_.size());
for (size_t i = 0; i < sink_var_names_.size(); ++i) {
framework::Variable* var = scope_.FindVar(sink_var_names_[i]);
PADDLE_ENFORCE_NOT_NULL(var,
"CustomReader's sink variable '%s' doesn't exist.");
const framework::LoDTensor& tensor = var->Get<framework::LoDTensor>();
(*out)[i].ShareDataWith(tensor);
(*out)[i].set_lod(tensor.lod());
}
} }
} // namespace reader } // namespace reader
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators::reader;
REGISTER_OPERATOR(create_custom_reader, ops::CreateCustomReaderOp,
ops::CreateCustomReaderOpMaker, ops::CustomReaderInferShape,
ops::CustomReaderInferVarType,
paddle::framework::EmptyGradOpMaker)
...@@ -117,6 +117,7 @@ void DecoratedReaderInferShape::operator()( ...@@ -117,6 +117,7 @@ void DecoratedReaderInferShape::operator()(
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]); boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
out_reader->SetLoDLevels(in_reader->GetLoDLevels()); out_reader->SetLoDLevels(in_reader->GetLoDLevels());
} }
void DecoratedReaderInferVarType::operator()( void DecoratedReaderInferVarType::operator()(
const framework::OpDesc& op_desc, framework::BlockDesc* block) const { const framework::OpDesc& op_desc, framework::BlockDesc* block) const {
std::string in_reader_name = op_desc.Input("UnderlyingReader")[0]; std::string in_reader_name = op_desc.Input("UnderlyingReader")[0];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册