未验证 提交 3874c383 编写于 作者: F fengjiayi 提交者: GitHub

Merge pull request #9596 from JiayiFeng/update_reader

Update reader
...@@ -39,10 +39,13 @@ class CreateBatchReaderOp : public framework::OperatorBase { ...@@ -39,10 +39,13 @@ class CreateBatchReaderOp : public framework::OperatorBase {
private: private:
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override { const platform::Place& dev_place) const override {
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
auto* out = scope.FindVar(Output("Out")) auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>(); ->template GetMutable<framework::ReaderHolder>();
if (out->Get() != nullptr) {
return;
}
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
out->Reset( out->Reset(
new BatchReader(underlying_reader.Get(), Attr<int>("batch_size"))); new BatchReader(underlying_reader.Get(), Attr<int>("batch_size")));
} }
......
...@@ -99,10 +99,13 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase { ...@@ -99,10 +99,13 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
private: private:
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override { const platform::Place& dev_place) const override {
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
auto* out = scope.FindVar(Output("Out")) auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>(); ->template GetMutable<framework::ReaderHolder>();
if (out->Get() != nullptr) {
return;
}
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
auto place_str = Attr<std::string>("place"); auto place_str = Attr<std::string>("place");
platform::Place place; platform::Place place;
......
...@@ -62,12 +62,15 @@ class CreateMultiPassReaderOp : public framework::OperatorBase { ...@@ -62,12 +62,15 @@ class CreateMultiPassReaderOp : public framework::OperatorBase {
private: private:
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override { const platform::Place& dev_place) const override {
auto* out = detail::Ref(scope.FindVar(Output("Out")))
.GetMutable<framework::ReaderHolder>();
if (out->Get() != nullptr) {
return;
}
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>(); ->Get<framework::ReaderHolder>();
auto& out = detail::Ref(scope.FindVar(Output("Out")));
int pass_num = Attr<int>("pass_num"); int pass_num = Attr<int>("pass_num");
out.GetMutable<framework::ReaderHolder>()->Reset( out->Reset(new MultiPassReader(underlying_reader.Get(), pass_num));
new MultiPassReader(underlying_reader.Get(), pass_num));
} }
}; };
......
...@@ -80,10 +80,14 @@ class CreateShuffleReaderOp : public framework::OperatorBase { ...@@ -80,10 +80,14 @@ class CreateShuffleReaderOp : public framework::OperatorBase {
private: private:
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override { const platform::Place& dev_place) const override {
auto* out = detail::Ref(scope.FindVar(Output("Out")))
.GetMutable<framework::ReaderHolder>();
if (out->Get() != nullptr) {
return;
}
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>(); ->Get<framework::ReaderHolder>();
auto& var = detail::Ref(scope.FindVar(Output("Out"))); out->Reset(
var.GetMutable<framework::ReaderHolder>()->Reset(
new ShuffleReader(underlying_reader.Get(), new ShuffleReader(underlying_reader.Get(),
static_cast<size_t>(Attr<int>("buffer_size")))); static_cast<size_t>(Attr<int>("buffer_size"))));
} }
......
...@@ -640,6 +640,20 @@ class Operator(object): ...@@ -640,6 +640,20 @@ class Operator(object):
""" """
return self.desc.block_attr(name) return self.desc.block_attr(name)
def all_attrs(self):
"""
Get the attribute dict
Returns(dict): The Operator's attribute dict
"""
attr_names = self.attr_names
attr_map = {}
for n in attr_names:
if n == 'sub_block':
attr_map[n] = self.block_attr(n)
else:
attr_map[n] = self.attr(n)
return attr_map
class Block(object): class Block(object):
def __init__(self, program, idx): def __init__(self, program, idx):
......
...@@ -255,7 +255,32 @@ def _copy_reader_var_(block, var): ...@@ -255,7 +255,32 @@ def _copy_reader_var_(block, var):
new_var.desc.set_shapes(var.desc.shapes()) new_var.desc.set_shapes(var.desc.shapes())
new_var.desc.set_dtypes(var.desc.dtypes()) new_var.desc.set_dtypes(var.desc.dtypes())
new_var.persistable = True new_var.persistable = True
return monkey_patch_reader_methods(new_var) return new_var
def _copy_reader_create_op_(block, op):
input_param_names = op.input_names
new_input_map = {}
for param_name in input_param_names:
new_input_map[param_name] = []
arg_names = op.input(param_name)
for arg_name in arg_names:
new_input_map[param_name].append(block.var(arg_name))
output_param_names = op.output_names
new_output_map = {}
for param_name in output_param_names:
new_output_map[param_name] = []
arg_names = op.output(param_name)
for arg_name in arg_names:
new_output_map[param_name].append(block.var(arg_name))
new_op = block.append_op(
type=op.type,
inputs=new_input_map,
outputs=new_output_map,
attrs=op.all_attrs())
return new_op
def open_recordio_file(filename, shapes, lod_levels, dtypes): def open_recordio_file(filename, shapes, lod_levels, dtypes):
...@@ -283,8 +308,9 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes): ...@@ -283,8 +308,9 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes):
startup_var.desc.set_dtypes(dtypes) startup_var.desc.set_dtypes(dtypes)
startup_var.persistable = True startup_var.persistable = True
return _copy_reader_var_(default_main_program().current_block(), main_prog_var = _copy_reader_var_(default_main_program().current_block(),
startup_var) startup_var)
return monkey_patch_reader_methods(main_prog_var)
def open_files(filenames, thread_num, shapes, lod_levels, dtypes): def open_files(filenames, thread_num, shapes, lod_levels, dtypes):
...@@ -313,22 +339,25 @@ def open_files(filenames, thread_num, shapes, lod_levels, dtypes): ...@@ -313,22 +339,25 @@ def open_files(filenames, thread_num, shapes, lod_levels, dtypes):
startup_var.desc.set_dtypes(dtypes) startup_var.desc.set_dtypes(dtypes)
startup_var.persistable = True startup_var.persistable = True
return _copy_reader_var_(default_main_program().current_block(), main_prog_var = _copy_reader_var_(default_main_program().current_block(),
startup_var) startup_var)
return monkey_patch_reader_methods(main_prog_var)
def __create_decorated_reader__(op_type, reader, attrs): def __create_decorated_reader__(op_type, reader, attrs):
var_name = unique_name(op_type) var_name = unique_name(op_type)
startup_blk = default_startup_program().current_block() startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=var_name) startup_var = startup_blk.create_var(name=var_name)
startup_blk.append_op( startop_op = startup_blk.append_op(
type=op_type, type=op_type,
inputs={'UnderlyingReader': reader}, inputs={'UnderlyingReader': reader},
outputs={'Out': [startup_var]}, outputs={'Out': [startup_var]},
attrs=attrs) attrs=attrs)
startup_var.persistable = True startup_var.persistable = True
return _copy_reader_var_(default_main_program().current_block(), main_prog_block = default_main_program().current_block()
startup_var) main_prog_var = _copy_reader_var_(main_prog_block, startup_var)
_copy_reader_create_op_(main_prog_block, startop_op)
return monkey_patch_reader_methods(main_prog_var)
def create_shuffle_reader(reader, buffer_size): def create_shuffle_reader(reader, buffer_size):
......
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
import unittest import unittest
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle import paddle.v2 as paddle
import paddle.dataset.mnist as mnist import paddle.v2.dataset.mnist as mnist
class TestRecordIO(unittest.TestCase): class TestRecordIO(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册