From 5416bac5d84b1d846744481505749df0a87db133 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Mon, 9 Apr 2018 15:47:46 +0800 Subject: [PATCH] Make shared decorated readers' creater be only in main_program --- .../reader/create_double_buffer_reader_op.cc | 9 ++++-- python/paddle/fluid/layers/io.py | 30 ++++++++++++++----- 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index ed868786ab2..d9f799f14d1 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -109,7 +109,9 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase { auto place_str = Attr("place"); platform::Place place; - if (place_str == "CPU") { + if (place_str == "AUTO") { + place = dev_place; + } else if (place_str == "CPU") { place = platform::CPUPlace(); } else { std::istringstream sin(place_str); @@ -140,8 +142,9 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase { enum_range.insert(string::Sprintf("CUDA:%d", i)); } enum_range.insert("CPU"); - AddAttr("place", "The double buffer place, default is CPU") - .SetDefault("CPU") + enum_range.insert("AUTO"); + AddAttr("place", "The double buffer place") + .SetDefault("AUTO") .InEnum({enum_range}); } }; diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index dbba1a46eb6..4901521db5b 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -440,7 +440,7 @@ def open_files(filenames, return monkey_patch_reader_methods(main_prog_reader) -def __create_decorated_reader__(op_type, reader, attrs={}): +def __create_unshared_decorated_reader__(op_type, reader, attrs={}): var_name = unique_name(op_type) startup_blk = default_startup_program().current_block() startup_var = startup_blk.create_var(name=var_name) @@ -456,26 +456,40 @@ def __create_decorated_reader__(op_type, reader, attrs={}): return monkey_patch_reader_methods(main_prog_var) +def __create_shared_decorated_reader__(op_type, reader, attrs={}): + new_reader_name = unique_name(op_type) + main_blk = default_main_program().current_block() + new_reader = main_blk.create_var(name=new_reader_name) + main_blk.append_op( + type=op_type, + inputs={'UnderlyingReader': reader}, + outputs={'Out': [new_reader]}, + attrs=attrs) + new_reader.persistable = True + new_reader.stop_gradient = True + return monkey_patch_reader_methods(new_reader) + + def shuffle(reader, buffer_size): - return __create_decorated_reader__('create_shuffle_reader', reader, - {'buffer_size': int(buffer_size)}) + return __create_unshared_decorated_reader__( + 'create_shuffle_reader', reader, {'buffer_size': int(buffer_size)}) def double_buffer(reader, place=None): attrs = dict() if place is not None: attrs['place'] = str(place).upper() - return __create_decorated_reader__('create_double_buffer_reader', reader, - attrs) + return __create_unshared_decorated_reader__('create_double_buffer_reader', + reader, attrs) def multi_pass(reader, pass_num): - return __create_decorated_reader__('create_multi_pass_reader', reader, - {'pass_num': int(pass_num)}) + return __create_shared_decorated_reader__( + 'create_multi_pass_reader', reader, {'pass_num': int(pass_num)}) def for_parallel(reader): - return __create_decorated_reader__('create_threaded_reader', reader) + return __create_shared_decorated_reader__('create_threaded_reader', reader) def read_file(file_obj): -- GitLab