提交 81470635 编写于 作者: F fengjiayi

follow comments

上级 0457f064
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace paddle {
......@@ -148,35 +149,31 @@ void CustomReader::ReadNext(std::vector<framework::LoDTensor>* out) {
// There is not next data.
return;
}
PADDLE_ENFORCE(
source_var_names_.size() == underlying_outs.size() &&
sink_var_names_.size() == underlying_outs.size(),
"The size of source_var_names(%d), the size of sink_var_names(%d) and "
"the size of underlying_outs(%d) are not consistent. Each feeding "
"element must have its own source and sink variable.",
source_var_names_.size(), sink_var_names_.size(), underlying_outs.size());
PADDLE_ENFORCE(source_var_names_.size() == underlying_outs.size(),
"The size of source_var_names(%d) and the size of "
"underlying_outs(%d) are not consistent. Each feeding element "
"must have its own source variable.",
source_var_names_.size(), underlying_outs.size());
// The scope for CustomReader's sub-block should be independent and shouldn't
// be any other computation scope's child. Otherwise, data preprocessing and
// compution cannot be concurrent.
auto* scope = new framework::Scope();
framework::Scope scope;
// 1. Copy LoDTensors from underlying reader's output to source variables.
for (size_t i = 0; i < source_var_names_.size(); ++i) {
framework::Variable* var = scope->Var(source_var_names_[i]);
framework::Variable* var = scope.Var(source_var_names_[i]);
framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>();
tensor->ShareDataWith(underlying_outs[i]);
tensor->set_lod(underlying_outs[i].lod());
}
// 2. Run the sub-block.
exe_.Run(program_, scope, sub_block_id_, false, true);
exe_.Run(program_, &scope, sub_block_id_, false, true);
// 3. Copy LoDTensors from sink variables to out.
out->resize(sink_var_names_.size());
for (size_t i = 0; i < sink_var_names_.size(); ++i) {
framework::Variable* var = scope->FindVar(sink_var_names_[i]);
PADDLE_ENFORCE_NOT_NULL(var);
const framework::LoDTensor& tensor = var->Get<framework::LoDTensor>();
const auto& tensor = detail::Ref(scope.FindVar(sink_var_names_[i]))
.Get<framework::LoDTensor>();
framework::TensorCopySync(tensor, platform::CPUPlace(), &(*out)[i]);
}
delete scope;
}
} // namespace reader
......
......@@ -559,15 +559,16 @@ class Preprocessor(object):
source_shapes = self.underlying_reader.desc.shapes()
source_dtypes = self.underlying_reader.desc.dtypes()
source_lod_levels = self.underlying_reader.desc.lod_levels()
self.source_var_names = []
self.source_var_names = [
unique_name("preprocessor_source")
for _ in xrange(len(source_shapes))
]
source_vars = []
for idx in xrange(len(source_shapes)):
self.source_var_names.append(unique_name("preprocessor_source"))
for var_name, shape, dtype, lod_level in zip(
self.source_var_names, source_shapes, source_dtypes,
source_lod_levels):
source_vars.append(self.main_prog.current_block().create_var(
name=self.source_var_names[-1],
shape=source_shapes[idx],
dtype=source_dtypes[idx],
lod_level=source_lod_levels[idx]))
name=var_name, shape=shape, dtype=dtype, lod_level=lod_level))
return source_vars
def outputs(self, *outs):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册