提交 239546a6 编写于 作者: F fengjiayi

add unit test and fix a bug

上级 2e42b31f
...@@ -22,12 +22,11 @@ namespace reader { ...@@ -22,12 +22,11 @@ namespace reader {
class CustomReader : public framework::DecoratedReader { class CustomReader : public framework::DecoratedReader {
public: public:
CustomReader(ReaderBase* reader, const framework::BlockDesc* sub_block, CustomReader(ReaderBase* reader, const framework::BlockDesc* sub_block,
const framework::Scope* scope, const platform::Place& dev_place, const platform::Place& dev_place,
const std::vector<std::string>& source_var_names, const std::vector<std::string>& source_var_names,
const std::vector<std::string>& sink_var_names) const std::vector<std::string>& sink_var_names)
: DecoratedReader(reader), : DecoratedReader(reader),
sub_block_(sub_block), sub_block_(sub_block),
scope_(scope),
exe_(framework::Executor(dev_place)), exe_(framework::Executor(dev_place)),
source_var_names_(source_var_names), source_var_names_(source_var_names),
sink_var_names_(sink_var_names) {} sink_var_names_(sink_var_names) {}
...@@ -37,12 +36,10 @@ class CustomReader : public framework::DecoratedReader { ...@@ -37,12 +36,10 @@ class CustomReader : public framework::DecoratedReader {
void UpdateBlockAndScope(const framework::BlockDesc* sub_block, void UpdateBlockAndScope(const framework::BlockDesc* sub_block,
const framework::Scope* scope) { const framework::Scope* scope) {
sub_block_ = sub_block; sub_block_ = sub_block;
scope_ = scope;
} }
private: private:
const framework::BlockDesc* sub_block_; const framework::BlockDesc* sub_block_;
const framework::Scope* scope_;
framework::Executor exe_; framework::Executor exe_;
std::vector<std::string> source_var_names_; std::vector<std::string> source_var_names_;
...@@ -67,7 +64,7 @@ class CreateCustomReaderOp : public framework::OperatorBase { ...@@ -67,7 +64,7 @@ class CreateCustomReaderOp : public framework::OperatorBase {
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>(); ->Get<framework::ReaderHolder>();
out->Reset( out->Reset(
new CustomReader(underlying_reader.Get(), sub_block, &scope, dev_place, new CustomReader(underlying_reader.Get(), sub_block, dev_place,
Attr<std::vector<std::string>>("source_var_names"), Attr<std::vector<std::string>>("source_var_names"),
Attr<std::vector<std::string>>("sink_var_names"))); Attr<std::vector<std::string>>("sink_var_names")));
} }
...@@ -150,27 +147,29 @@ void CustomReader::ReadNext(std::vector<framework::LoDTensor>* out) { ...@@ -150,27 +147,29 @@ void CustomReader::ReadNext(std::vector<framework::LoDTensor>* out) {
"the size of underlying_outs(%d) are not consistent. Each feeding " "the size of underlying_outs(%d) are not consistent. Each feeding "
"element must have its own source and sink variable.", "element must have its own source and sink variable.",
source_var_names_.size(), sink_var_names_.size(), underlying_outs.size()); source_var_names_.size(), sink_var_names_.size(), underlying_outs.size());
// The scope for CustomReader's sub-block should be independent and shouldn't
framework::Scope* exe_scope = &scope_->NewScope(); // be any other computation scope's child. Otherwise, data preprocessing and
// compution cannot be concurrent.
auto* scope = new framework::Scope();
// 1. Copy LoDTensors from underlying reader's output to source variables. // 1. Copy LoDTensors from underlying reader's output to source variables.
for (size_t i = 0; i < source_var_names_.size(); ++i) { for (size_t i = 0; i < source_var_names_.size(); ++i) {
framework::Variable* var = exe_scope->Var(source_var_names_[i]); framework::Variable* var = scope->Var(source_var_names_[i]);
framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>(); framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>();
tensor->ShareDataWith(underlying_outs[i]); tensor->ShareDataWith(underlying_outs[i]);
tensor->set_lod(underlying_outs[i].lod()); tensor->set_lod(underlying_outs[i].lod());
} }
// 2. Run the sub-block. // 2. Run the sub-block.
framework::ProgramDesc* program = sub_block_->Program(); framework::ProgramDesc* program = sub_block_->Program();
exe_.Run(*program, exe_scope, sub_block_->ID(), false, true); exe_.Run(*program, scope, sub_block_->ID(), false, true);
// 3. Copy LoDTensors from sink variables to out. // 3. Copy LoDTensors from sink variables to out.
out->resize(sink_var_names_.size()); out->resize(sink_var_names_.size());
for (size_t i = 0; i < sink_var_names_.size(); ++i) { for (size_t i = 0; i < sink_var_names_.size(); ++i) {
framework::Variable* var = exe_scope->FindVar(sink_var_names_[i]); framework::Variable* var = scope->FindVar(sink_var_names_[i]);
PADDLE_ENFORCE_NOT_NULL(var); PADDLE_ENFORCE_NOT_NULL(var);
const framework::LoDTensor& tensor = var->Get<framework::LoDTensor>(); const framework::LoDTensor& tensor = var->Get<framework::LoDTensor>();
framework::TensorCopySync(tensor, platform::CPUPlace(), &(*out)[i]); framework::TensorCopySync(tensor, platform::CPUPlace(), &(*out)[i]);
} }
scope_->DeleteScope(exe_scope); delete scope;
} }
} // namespace reader } // namespace reader
......
...@@ -74,7 +74,8 @@ class TestPreprocessor(unittest.TestCase): ...@@ -74,7 +74,8 @@ class TestPreprocessor(unittest.TestCase):
lbl_out = lbl + 1 lbl_out = lbl + 1
preprocessor.outputs(img_out, lbl_out) preprocessor.outputs(img_out, lbl_out)
img, lbl = fluid.layers.io.read_file(preprocessor()) data_file = fluid.layers.io.double_buffer(preprocessor())
img, lbl = fluid.layers.io.read_file(data_file)
if fluid.core.is_compiled_with_cuda(): if fluid.core.is_compiled_with_cuda():
place = fluid.CUDAPlace(0) place = fluid.CUDAPlace(0)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册