Commit 81470635 authored by fengjiayi

follow comments

Parent 0457f064
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "paddle/fluid/framework/executor.h"
+#include "paddle/fluid/operators/detail/safe_ref.h"
 #include "paddle/fluid/operators/reader/reader_op_registry.h"
 
 namespace paddle {
@@ -148,35 +149,31 @@ void CustomReader::ReadNext(std::vector<framework::LoDTensor>* out) {
     // There is not next data.
     return;
   }
-  PADDLE_ENFORCE(
-      source_var_names_.size() == underlying_outs.size() &&
-          sink_var_names_.size() == underlying_outs.size(),
-      "The size of source_var_names(%d), the size of sink_var_names(%d) and "
-      "the size of underlying_outs(%d) are not consistent. Each feeding "
-      "element must have its own source and sink variable.",
-      source_var_names_.size(), sink_var_names_.size(), underlying_outs.size());
+  PADDLE_ENFORCE(source_var_names_.size() == underlying_outs.size(),
+                 "The size of source_var_names(%d) and the size of "
+                 "underlying_outs(%d) are not consistent. Each feeding element "
+                 "must have its own source variable.",
+                 source_var_names_.size(), underlying_outs.size());
   // The scope for CustomReader's sub-block should be independent and shouldn't
   // be any other computation scope's child. Otherwise, data preprocessing and
   // compution cannot be concurrent.
-  auto* scope = new framework::Scope();
+  framework::Scope scope;
   // 1. Copy LoDTensors from underlying reader's output to source variables.
   for (size_t i = 0; i < source_var_names_.size(); ++i) {
-    framework::Variable* var = scope->Var(source_var_names_[i]);
+    framework::Variable* var = scope.Var(source_var_names_[i]);
     framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>();
     tensor->ShareDataWith(underlying_outs[i]);
     tensor->set_lod(underlying_outs[i].lod());
   }
   // 2. Run the sub-block.
-  exe_.Run(program_, scope, sub_block_id_, false, true);
+  exe_.Run(program_, &scope, sub_block_id_, false, true);
   // 3. Copy LoDTensors from sink variables to out.
   out->resize(sink_var_names_.size());
   for (size_t i = 0; i < sink_var_names_.size(); ++i) {
-    framework::Variable* var = scope->FindVar(sink_var_names_[i]);
-    PADDLE_ENFORCE_NOT_NULL(var);
-    const framework::LoDTensor& tensor = var->Get<framework::LoDTensor>();
+    const auto& tensor = detail::Ref(scope.FindVar(sink_var_names_[i]))
+                             .Get<framework::LoDTensor>();
     framework::TensorCopySync(tensor, platform::CPUPlace(), &(*out)[i]);
   }
-  delete scope;
 }
 
 } // namespace reader
...
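
Two things change on the C++ side. First, `detail::Ref` (from the newly included safe_ref.h) wraps the pointer returned by `FindVar` and, as I read it, performs the null check that the removed `PADDLE_ENFORCE_NOT_NULL` did, folding lookup, check, and access into one expression. Second, the heap-allocated scope becomes a stack object, which is why `delete scope;` can be dropped: if `exe_.Run` or an enforce check throws, a stack-allocated scope is still destroyed during unwinding, while a `delete` placed after the throwing call would never run. A minimal sketch of that exception-safety point, with `Scope` and `Run` as stand-ins rather than the real framework classes:

#include <cstdio>
#include <stdexcept>

// Stand-ins for framework::Scope and Executor::Run, for illustration only.
struct Scope {
  ~Scope() { std::puts("scope destroyed"); }
};

void Run(Scope*) { throw std::runtime_error("op failed"); }

void leaky() {
  Scope* scope = new Scope();
  Run(scope);    // throws, so the next line never runs
  delete scope;  // skipped on the exception path -> leak
}

void safe() {
  Scope scope;  // destroyed during stack unwinding, even on a throw
  Run(&scope);
}

int main() {
  try { leaky(); } catch (const std::exception&) { std::puts("leaky: no destructor ran"); }
  try { safe(); }  catch (const std::exception&) { std::puts("safe: destructor ran"); }
  return 0;
}

Running this prints "scope destroyed" only on the `safe()` path, which is the leak the commit closes.
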
@@ -559,15 +559,16 @@ class Preprocessor(object):
         source_shapes = self.underlying_reader.desc.shapes()
         source_dtypes = self.underlying_reader.desc.dtypes()
         source_lod_levels = self.underlying_reader.desc.lod_levels()
-        self.source_var_names = []
+        self.source_var_names = [
+            unique_name("preprocessor_source")
+            for _ in xrange(len(source_shapes))
+        ]
         source_vars = []
-        for idx in xrange(len(source_shapes)):
-            self.source_var_names.append(unique_name("preprocessor_source"))
-            source_vars.append(self.main_prog.current_block().create_var(
-                name=self.source_var_names[-1],
-                shape=source_shapes[idx],
-                dtype=source_dtypes[idx],
-                lod_level=source_lod_levels[idx]))
+        for var_name, shape, dtype, lod_level in zip(
+                self.source_var_names, source_shapes, source_dtypes,
+                source_lod_levels):
+            source_vars.append(self.main_prog.current_block().create_var(
+                name=var_name, shape=shape, dtype=dtype, lod_level=lod_level))
         return source_vars
 
     def outputs(self, *outs):
...
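
On the Python side, the variable names are now generated up front with a list comprehension, and the index-based loop becomes a single `zip` over the parallel metadata lists. A standalone sketch of the same pattern, using hypothetical placeholder data instead of the reader's real metadata (and Python 3's `range`/`print`, whereas the commit targets Python 2's `xrange`):

from itertools import count

# Hypothetical per-variable metadata, standing in for the values that
# self.underlying_reader.desc returns in the real code.
source_shapes = [(3, 4), (5,)]
source_dtypes = ["float32", "int64"]
source_lod_levels = [0, 1]

_ids = count()

def unique_name(prefix):
    # Simplified stand-in for fluid's unique_name helper.
    return "%s_%d" % (prefix, next(_ids))

# 1. Generate one unique name per source variable up front.
source_var_names = [
    unique_name("preprocessor_source") for _ in range(len(source_shapes))
]

# 2. Zip the parallel lists instead of indexing with range(len(...)):
#    each iteration yields one coherent (name, shape, dtype, lod_level) tuple.
source_vars = [
    {"name": name, "shape": shape, "dtype": dtype, "lod_level": lod_level}
    for name, shape, dtype, lod_level in zip(
        source_var_names, source_shapes, source_dtypes, source_lod_levels)
]

print(source_vars)

The `zip` form cannot index out of bounds and makes the one-name-per-shape pairing explicit, which is presumably what the review comments asked for.
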