提交 5416bac5 编写于 作者: F fengjiayi

Make shared decorated readers' creater be only in main_program

上级 3f90a583
...@@ -109,7 +109,9 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase { ...@@ -109,7 +109,9 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
auto place_str = Attr<std::string>("place"); auto place_str = Attr<std::string>("place");
platform::Place place; platform::Place place;
if (place_str == "CPU") { if (place_str == "AUTO") {
place = dev_place;
} else if (place_str == "CPU") {
place = platform::CPUPlace(); place = platform::CPUPlace();
} else { } else {
std::istringstream sin(place_str); std::istringstream sin(place_str);
...@@ -140,8 +142,9 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase { ...@@ -140,8 +142,9 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
enum_range.insert(string::Sprintf("CUDA:%d", i)); enum_range.insert(string::Sprintf("CUDA:%d", i));
} }
enum_range.insert("CPU"); enum_range.insert("CPU");
AddAttr<std::string>("place", "The double buffer place, default is CPU") enum_range.insert("AUTO");
.SetDefault("CPU") AddAttr<std::string>("place", "The double buffer place")
.SetDefault("AUTO")
.InEnum({enum_range}); .InEnum({enum_range});
} }
}; };
......
...@@ -440,7 +440,7 @@ def open_files(filenames, ...@@ -440,7 +440,7 @@ def open_files(filenames,
return monkey_patch_reader_methods(main_prog_reader) return monkey_patch_reader_methods(main_prog_reader)
def __create_decorated_reader__(op_type, reader, attrs={}): def __create_unshared_decorated_reader__(op_type, reader, attrs={}):
var_name = unique_name(op_type) var_name = unique_name(op_type)
startup_blk = default_startup_program().current_block() startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=var_name) startup_var = startup_blk.create_var(name=var_name)
...@@ -456,26 +456,40 @@ def __create_decorated_reader__(op_type, reader, attrs={}): ...@@ -456,26 +456,40 @@ def __create_decorated_reader__(op_type, reader, attrs={}):
return monkey_patch_reader_methods(main_prog_var) return monkey_patch_reader_methods(main_prog_var)
def __create_shared_decorated_reader__(op_type, reader, attrs={}):
new_reader_name = unique_name(op_type)
main_blk = default_main_program().current_block()
new_reader = main_blk.create_var(name=new_reader_name)
main_blk.append_op(
type=op_type,
inputs={'UnderlyingReader': reader},
outputs={'Out': [new_reader]},
attrs=attrs)
new_reader.persistable = True
new_reader.stop_gradient = True
return monkey_patch_reader_methods(new_reader)
def shuffle(reader, buffer_size): def shuffle(reader, buffer_size):
return __create_decorated_reader__('create_shuffle_reader', reader, return __create_unshared_decorated_reader__(
{'buffer_size': int(buffer_size)}) 'create_shuffle_reader', reader, {'buffer_size': int(buffer_size)})
def double_buffer(reader, place=None): def double_buffer(reader, place=None):
attrs = dict() attrs = dict()
if place is not None: if place is not None:
attrs['place'] = str(place).upper() attrs['place'] = str(place).upper()
return __create_decorated_reader__('create_double_buffer_reader', reader, return __create_unshared_decorated_reader__('create_double_buffer_reader',
attrs) reader, attrs)
def multi_pass(reader, pass_num): def multi_pass(reader, pass_num):
return __create_decorated_reader__('create_multi_pass_reader', reader, return __create_shared_decorated_reader__(
{'pass_num': int(pass_num)}) 'create_multi_pass_reader', reader, {'pass_num': int(pass_num)})
def for_parallel(reader): def for_parallel(reader):
return __create_decorated_reader__('create_threaded_reader', reader) return __create_shared_decorated_reader__('create_threaded_reader', reader)
def read_file(file_obj): def read_file(file_obj):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册