提交 44d5f42a 编写于 作者: F fengjiayi

update reader

上级 a4e437d5
......@@ -39,10 +39,13 @@ class CreateBatchReaderOp : public framework::OperatorBase {
private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
if (out->Get() != nullptr) {
return;
}
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
out->Reset(
new BatchReader(underlying_reader.Get(), Attr<int>("batch_size")));
}
......
......@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <thread>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
......@@ -98,10 +97,13 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
if (out->Get() != nullptr) {
return;
}
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
auto place_str = Attr<std::string>("place");
platform::Place place;
......
......@@ -62,12 +62,15 @@ class CreateMultiPassReaderOp : public framework::OperatorBase {
private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
auto* out = detail::Ref(scope.FindVar(Output("Out")))
.GetMutable<framework::ReaderHolder>();
if (out->Get() != nullptr) {
return;
}
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
auto& out = detail::Ref(scope.FindVar(Output("Out")));
int pass_num = Attr<int>("pass_num");
out.GetMutable<framework::ReaderHolder>()->Reset(
new MultiPassReader(underlying_reader.Get(), pass_num));
out->Reset(new MultiPassReader(underlying_reader.Get(), pass_num));
}
};
......
......@@ -80,10 +80,14 @@ class CreateShuffleReaderOp : public framework::OperatorBase {
private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
auto* out = detail::Ref(scope.FindVar(Output("Out")))
.GetMutable<framework::ReaderHolder>();
if (out->Get() != nullptr) {
return;
}
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
auto& var = detail::Ref(scope.FindVar(Output("Out")));
var.GetMutable<framework::ReaderHolder>()->Reset(
out->Reset(
new ShuffleReader(underlying_reader.Get(),
static_cast<size_t>(Attr<int>("buffer_size"))));
}
......
......@@ -640,6 +640,21 @@ class Operator(object):
"""
return self.desc.block_attr(name)
@property
def attrs(self):
"""
Get the attribute dict
Returns(dict): The Operator's attribute dict
"""
attr_names = self.attr_names
attr_map = {}
for n in attr_names:
if n == 'sub_block':
attr_map[n] = self.block_attr(n)
else:
attr_map[n] = self.attr(n)
return attr_map
class Block(object):
def __init__(self, program, idx):
......
......@@ -255,7 +255,22 @@ def _copy_reader_var_(block, var):
new_var.desc.set_shapes(var.desc.shapes())
new_var.desc.set_dtypes(var.desc.dtypes())
new_var.persistable = True
return monkey_patch_reader_methods(new_var)
return new_var
def _copy_reader_create_op_(block, op):
def _find_vars_(block, name_list):
res = {}
for n in name_list:
var = block.var(n)
res[n] = var
return res
input_map = _find_vars_(block, op.input_names)
output_map = _find_vars_(block, op.output_names)
new_op = block.append_op(
type=op.type, inputs=input_map, outputs=output_map, attrs=op.attrs)
return new_op
def open_recordio_file(filename, shapes, lod_levels, dtypes):
......@@ -283,8 +298,9 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes):
startup_var.desc.set_dtypes(dtypes)
startup_var.persistable = True
return _copy_reader_var_(default_main_program().current_block(),
main_prog_var = _copy_reader_var_(default_main_program().current_block(),
startup_var)
return monkey_patch_reader_methods(main_prog_var)
def open_files(filenames, thread_num, shapes, lod_levels, dtypes):
......@@ -313,22 +329,25 @@ def open_files(filenames, thread_num, shapes, lod_levels, dtypes):
startup_var.desc.set_dtypes(dtypes)
startup_var.persistable = True
return _copy_reader_var_(default_main_program().current_block(),
main_prog_var = _copy_reader_var_(default_main_program().current_block(),
startup_var)
return monkey_patch_reader_methods(main_prog_var)
def __create_decorated_reader__(op_type, reader, attrs):
var_name = unique_name(op_type)
startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=var_name)
startup_blk.append_op(
startop_op = startup_blk.append_op(
type=op_type,
inputs={'UnderlyingReader': reader},
outputs={'Out': [startup_var]},
attrs=attrs)
startup_var.persistable = True
return _copy_reader_var_(default_main_program().current_block(),
startup_var)
main_prog_block = default_main_program().current_block()
main_prog_var = _copy_reader_var_(main_prog_block, startup_var)
_copy_reader_create_op_(main_prog_block, startop_op)
return monkey_patch_reader_methods(main_prog_var)
def create_shuffle_reader(reader, buffer_size):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册