提交 44d5f42a 编写于 作者: F fengjiayi

update reader

上级 a4e437d5
...@@ -39,10 +39,13 @@ class CreateBatchReaderOp : public framework::OperatorBase { ...@@ -39,10 +39,13 @@ class CreateBatchReaderOp : public framework::OperatorBase {
private: private:
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override { const platform::Place& dev_place) const override {
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
auto* out = scope.FindVar(Output("Out")) auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>(); ->template GetMutable<framework::ReaderHolder>();
if (out->Get() != nullptr) {
return;
}
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
out->Reset( out->Reset(
new BatchReader(underlying_reader.Get(), Attr<int>("batch_size"))); new BatchReader(underlying_reader.Get(), Attr<int>("batch_size")));
} }
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <thread>
#include "paddle/fluid/framework/channel.h" #include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h" #include "paddle/fluid/operators/reader/reader_op_registry.h"
...@@ -98,10 +97,13 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase { ...@@ -98,10 +97,13 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
private: private:
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override { const platform::Place& dev_place) const override {
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
auto* out = scope.FindVar(Output("Out")) auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>(); ->template GetMutable<framework::ReaderHolder>();
if (out->Get() != nullptr) {
return;
}
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
auto place_str = Attr<std::string>("place"); auto place_str = Attr<std::string>("place");
platform::Place place; platform::Place place;
......
...@@ -62,12 +62,15 @@ class CreateMultiPassReaderOp : public framework::OperatorBase { ...@@ -62,12 +62,15 @@ class CreateMultiPassReaderOp : public framework::OperatorBase {
private: private:
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override { const platform::Place& dev_place) const override {
auto* out = detail::Ref(scope.FindVar(Output("Out")))
.GetMutable<framework::ReaderHolder>();
if (out->Get() != nullptr) {
return;
}
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>(); ->Get<framework::ReaderHolder>();
auto& out = detail::Ref(scope.FindVar(Output("Out")));
int pass_num = Attr<int>("pass_num"); int pass_num = Attr<int>("pass_num");
out.GetMutable<framework::ReaderHolder>()->Reset( out->Reset(new MultiPassReader(underlying_reader.Get(), pass_num));
new MultiPassReader(underlying_reader.Get(), pass_num));
} }
}; };
......
...@@ -80,10 +80,14 @@ class CreateShuffleReaderOp : public framework::OperatorBase { ...@@ -80,10 +80,14 @@ class CreateShuffleReaderOp : public framework::OperatorBase {
private: private:
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override { const platform::Place& dev_place) const override {
auto* out = detail::Ref(scope.FindVar(Output("Out")))
.GetMutable<framework::ReaderHolder>();
if (out->Get() != nullptr) {
return;
}
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>(); ->Get<framework::ReaderHolder>();
auto& var = detail::Ref(scope.FindVar(Output("Out"))); out->Reset(
var.GetMutable<framework::ReaderHolder>()->Reset(
new ShuffleReader(underlying_reader.Get(), new ShuffleReader(underlying_reader.Get(),
static_cast<size_t>(Attr<int>("buffer_size")))); static_cast<size_t>(Attr<int>("buffer_size"))));
} }
......
...@@ -640,6 +640,21 @@ class Operator(object): ...@@ -640,6 +640,21 @@ class Operator(object):
""" """
return self.desc.block_attr(name) return self.desc.block_attr(name)
@property
def attrs(self):
"""
Get the attribute dict
Returns(dict): The Operator's attribute dict
"""
attr_names = self.attr_names
attr_map = {}
for n in attr_names:
if n == 'sub_block':
attr_map[n] = self.block_attr(n)
else:
attr_map[n] = self.attr(n)
return attr_map
class Block(object): class Block(object):
def __init__(self, program, idx): def __init__(self, program, idx):
......
...@@ -255,7 +255,22 @@ def _copy_reader_var_(block, var): ...@@ -255,7 +255,22 @@ def _copy_reader_var_(block, var):
new_var.desc.set_shapes(var.desc.shapes()) new_var.desc.set_shapes(var.desc.shapes())
new_var.desc.set_dtypes(var.desc.dtypes()) new_var.desc.set_dtypes(var.desc.dtypes())
new_var.persistable = True new_var.persistable = True
return monkey_patch_reader_methods(new_var) return new_var
def _copy_reader_create_op_(block, op):
def _find_vars_(block, name_list):
res = {}
for n in name_list:
var = block.var(n)
res[n] = var
return res
input_map = _find_vars_(block, op.input_names)
output_map = _find_vars_(block, op.output_names)
new_op = block.append_op(
type=op.type, inputs=input_map, outputs=output_map, attrs=op.attrs)
return new_op
def open_recordio_file(filename, shapes, lod_levels, dtypes): def open_recordio_file(filename, shapes, lod_levels, dtypes):
...@@ -283,8 +298,9 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes): ...@@ -283,8 +298,9 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes):
startup_var.desc.set_dtypes(dtypes) startup_var.desc.set_dtypes(dtypes)
startup_var.persistable = True startup_var.persistable = True
return _copy_reader_var_(default_main_program().current_block(), main_prog_var = _copy_reader_var_(default_main_program().current_block(),
startup_var) startup_var)
return monkey_patch_reader_methods(main_prog_var)
def open_files(filenames, thread_num, shapes, lod_levels, dtypes): def open_files(filenames, thread_num, shapes, lod_levels, dtypes):
...@@ -313,22 +329,25 @@ def open_files(filenames, thread_num, shapes, lod_levels, dtypes): ...@@ -313,22 +329,25 @@ def open_files(filenames, thread_num, shapes, lod_levels, dtypes):
startup_var.desc.set_dtypes(dtypes) startup_var.desc.set_dtypes(dtypes)
startup_var.persistable = True startup_var.persistable = True
return _copy_reader_var_(default_main_program().current_block(), main_prog_var = _copy_reader_var_(default_main_program().current_block(),
startup_var) startup_var)
return monkey_patch_reader_methods(main_prog_var)
def __create_decorated_reader__(op_type, reader, attrs): def __create_decorated_reader__(op_type, reader, attrs):
var_name = unique_name(op_type) var_name = unique_name(op_type)
startup_blk = default_startup_program().current_block() startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=var_name) startup_var = startup_blk.create_var(name=var_name)
startup_blk.append_op( startop_op = startup_blk.append_op(
type=op_type, type=op_type,
inputs={'UnderlyingReader': reader}, inputs={'UnderlyingReader': reader},
outputs={'Out': [startup_var]}, outputs={'Out': [startup_var]},
attrs=attrs) attrs=attrs)
startup_var.persistable = True startup_var.persistable = True
return _copy_reader_var_(default_main_program().current_block(), main_prog_block = default_main_program().current_block()
startup_var) main_prog_var = _copy_reader_var_(main_prog_block, startup_var)
_copy_reader_create_op_(main_prog_block, startop_op)
return monkey_patch_reader_methods(main_prog_var)
def create_shuffle_reader(reader, buffer_size): def create_shuffle_reader(reader, buffer_size):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册