提交 5416bac5 编写于 作者: F fengjiayi

Make shared decorated readers' creater be only in main_program

上级 3f90a583
......@@ -109,7 +109,9 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
auto place_str = Attr<std::string>("place");
platform::Place place;
if (place_str == "CPU") {
if (place_str == "AUTO") {
place = dev_place;
} else if (place_str == "CPU") {
place = platform::CPUPlace();
} else {
std::istringstream sin(place_str);
......@@ -140,8 +142,9 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
enum_range.insert(string::Sprintf("CUDA:%d", i));
}
enum_range.insert("CPU");
AddAttr<std::string>("place", "The double buffer place, default is CPU")
.SetDefault("CPU")
enum_range.insert("AUTO");
AddAttr<std::string>("place", "The double buffer place")
.SetDefault("AUTO")
.InEnum({enum_range});
}
};
......
......@@ -440,7 +440,7 @@ def open_files(filenames,
return monkey_patch_reader_methods(main_prog_reader)
def __create_decorated_reader__(op_type, reader, attrs={}):
def __create_unshared_decorated_reader__(op_type, reader, attrs={}):
var_name = unique_name(op_type)
startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=var_name)
......@@ -456,26 +456,40 @@ def __create_decorated_reader__(op_type, reader, attrs={}):
return monkey_patch_reader_methods(main_prog_var)
def __create_shared_decorated_reader__(op_type, reader, attrs={}):
new_reader_name = unique_name(op_type)
main_blk = default_main_program().current_block()
new_reader = main_blk.create_var(name=new_reader_name)
main_blk.append_op(
type=op_type,
inputs={'UnderlyingReader': reader},
outputs={'Out': [new_reader]},
attrs=attrs)
new_reader.persistable = True
new_reader.stop_gradient = True
return monkey_patch_reader_methods(new_reader)
def shuffle(reader, buffer_size):
return __create_decorated_reader__('create_shuffle_reader', reader,
{'buffer_size': int(buffer_size)})
return __create_unshared_decorated_reader__(
'create_shuffle_reader', reader, {'buffer_size': int(buffer_size)})
def double_buffer(reader, place=None):
attrs = dict()
if place is not None:
attrs['place'] = str(place).upper()
return __create_decorated_reader__('create_double_buffer_reader', reader,
attrs)
return __create_unshared_decorated_reader__('create_double_buffer_reader',
reader, attrs)
def multi_pass(reader, pass_num):
return __create_decorated_reader__('create_multi_pass_reader', reader,
{'pass_num': int(pass_num)})
return __create_shared_decorated_reader__(
'create_multi_pass_reader', reader, {'pass_num': int(pass_num)})
def for_parallel(reader):
return __create_decorated_reader__('create_threaded_reader', reader)
return __create_shared_decorated_reader__('create_threaded_reader', reader)
def read_file(file_obj):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册