提交 4b395b08 编写于 作者: F fengjiayi

fix errors

上级 df8fbf80
......@@ -21,8 +21,8 @@ namespace reader {
class CustomReader : public framework::DecoratedReader {
public:
CustomReader(ReaderBase* reader, const framework::BlockDesc& sub_block,
const framework::Scope& scope, const platform::Place& dev_place,
CustomReader(ReaderBase* reader, const framework::BlockDesc* sub_block,
const framework::Scope* scope, const platform::Place& dev_place,
const std::vector<std::string>& source_var_names,
const std::vector<std::string>& sink_var_names)
: DecoratedReader(reader),
......@@ -34,9 +34,15 @@ class CustomReader : public framework::DecoratedReader {
void ReadNext(std::vector<framework::LoDTensor>* out) override;
void UpdateBlockAndScope(const framework::BlockDesc* sub_block,
const framework::Scope* scope) {
sub_block_ = sub_block;
scope_ = scope;
}
private:
const framework::BlockDesc& sub_block_;
const framework::Scope& scope_;
const framework::BlockDesc* sub_block_;
const framework::Scope* scope_;
platform::Place dev_place_;
std::vector<std::string> source_var_names_;
......@@ -52,15 +58,18 @@ class CreateCustomReaderOp : public framework::OperatorBase {
const platform::Place& dev_place) const override {
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
auto* sub_block = Attr<framework::BlockDesc*>("sub_block");
if (out->Get() != nullptr) {
auto* custom_reader = reinterpret_cast<CustomReader*>(out->Get());
custom_reader->UpdateBlockAndScope(sub_block, &scope);
return;
}
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
out->Reset(new CustomReader(
underlying_reader.Get(), *Attr<framework::BlockDesc*>("sub_block"),
scope, dev_place, Attr<std::vector<std::string>>("source_var_names"),
Attr<std::vector<std::string>>("sink_var_names")));
out->Reset(
new CustomReader(underlying_reader.Get(), sub_block, &scope, dev_place,
Attr<std::vector<std::string>>("source_var_names"),
Attr<std::vector<std::string>>("sink_var_names")));
}
};
......@@ -141,31 +150,28 @@ void CustomReader::ReadNext(std::vector<framework::LoDTensor>* out) {
"the size of underlying_outs(%d) are not consistent. Each feeding "
"element must have its own source and sink variable.",
source_var_names_.size(), sink_var_names_.size(), underlying_outs.size());
framework::Scope* exe_scope = &scope_->NewScope();
// 1. Copy LoDTensors from underlying reader's output to source variables.
for (size_t i = 0; i < source_var_names_.size(); ++i) {
framework::Variable* var = scope_.FindVar(source_var_names_[i]);
PADDLE_ENFORCE_NOT_NULL(
var, "CustomReader's source variable '%s' doesn't exist.");
framework::Variable* var = exe_scope->Var(source_var_names_[i]);
framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>();
tensor->ShareDataWith(underlying_outs[i]);
tensor->set_lod(underlying_outs[i].lod());
}
// 2. Run the sub-block.
framework::Executor executor(dev_place_);
framework::ProgramDesc* program = sub_block_.Program();
framework::Scope* exe_scope = &scope_.NewScope();
executor.Run(*program, exe_scope, sub_block_.ID(), false, true);
scope_.DeleteScope(exe_scope);
framework::ProgramDesc* program = sub_block_->Program();
executor.Run(*program, exe_scope, sub_block_->ID(), false, true);
// 3. Copy LoDTensors from sink variables to out.
out->resize(sink_var_names_.size());
for (size_t i = 0; i < sink_var_names_.size(); ++i) {
framework::Variable* var = scope_.FindVar(sink_var_names_[i]);
PADDLE_ENFORCE_NOT_NULL(var,
"CustomReader's sink variable '%s' doesn't exist.");
framework::Variable* var = exe_scope->FindVar(sink_var_names_[i]);
PADDLE_ENFORCE_NOT_NULL(var);
const framework::LoDTensor& tensor = var->Get<framework::LoDTensor>();
(*out)[i].ShareDataWith(tensor);
(*out)[i].set_lod(tensor.lod());
framework::TensorCopySync(tensor, platform::CPUPlace(), &(*out)[i]);
}
scope_->DeleteScope(exe_scope);
}
} // namespace reader
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import unittest
import numpy as np
import paddle.fluid as fluid
import paddle.v2 as paddle
......@@ -35,6 +36,31 @@ class TestPreprocessor(unittest.TestCase):
'./mnist_for_preprocessor_test.recordio', reader, feeder)
def test_main(self):
N = 10
img_expected_res = []
lbl_expected_res = []
with fluid.program_guard(fluid.Program(), fluid.Program()):
data_file = fluid.layers.io.open_recordio_file(
'./mnist_for_preprocessor_test.recordio',
shapes=[[-1, 784], [-1, 1]],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
img, lbl = fluid.layers.io.read_file(data_file)
if fluid.core.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
else:
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
for _ in range(N):
img_v, lbl_v = exe.run(fetch_list=[img, lbl])
img_expected_res.append(img_v / 2)
lbl_expected_res.append(lbl_v + 1)
img_actual_res = []
lbl_actual_res = []
with fluid.program_guard(fluid.Program(), fluid.Program()):
data_file = fluid.layers.io.open_recordio_file(
'./mnist_for_preprocessor_test.recordio',
......@@ -48,8 +74,7 @@ class TestPreprocessor(unittest.TestCase):
lbl_out = lbl + 1
preprocessor.outputs(img_out, lbl_out)
img_before, lbl_before = fluid.layers.io.read_file(data_file)
img_after, lbl_after = fluid.layers.io.read_file(preprocessor())
img, lbl = fluid.layers.io.read_file(preprocessor())
if fluid.core.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
......@@ -57,10 +82,11 @@ class TestPreprocessor(unittest.TestCase):
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
for _ in range(N):
img_v, lbl_v = exe.run(fetch_list=[img, lbl])
img_actual_res.append(img_v)
lbl_actual_res.append(lbl_v)
for _ in range(5):
img_b, lbl_b, img_a, lbl_a = exe.run(
fetch_list=[img_before, lbl_before, img_after, lbl_after])
self.assertEqual(img_b / 2, img_a)
self.assertEqual(lbl_b + 1, lbl_a)
for idx in range(N):
np.allclose(img_expected_res[idx], img_actual_res[idx])
np.allclose(lbl_expected_res[idx], lbl_actual_res[idx])
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册