提交 ee178d5a 编写于 作者: J JiayiFeng

fix bugs

上级 fe48dfcb
...@@ -115,14 +115,12 @@ void ParallelExecutor::BCastParamsToGPUs( ...@@ -115,14 +115,12 @@ void ParallelExecutor::BCastParamsToGPUs(
for (auto &var : vars) { for (auto &var : vars) {
auto *main_var = main_scope->FindVar(var); auto *main_var = main_scope->FindVar(var);
if (!main_var->IsType<LoDTensor>()) { if (main_var == nullptr || !main_var->IsType<LoDTensor>()) {
continue; continue;
} }
auto &main_tensor = main_var->Get<LoDTensor>(); auto &main_tensor = main_var->Get<LoDTensor>();
auto &dims = main_tensor.dims(); auto &dims = main_tensor.dims();
if (paddle::platform::is_gpu_place(main_tensor.place())) { if (paddle::platform::is_gpu_place(main_tensor.place())) {
size_t numel = main_tensor.numel(); size_t numel = main_tensor.numel();
ncclDataType_t data_type = platform::ToNCCLDataType(main_tensor.type()); ncclDataType_t data_type = platform::ToNCCLDataType(main_tensor.type());
......
...@@ -66,13 +66,6 @@ class ReadOp : public framework::OperatorBase { ...@@ -66,13 +66,6 @@ class ReadOp : public framework::OperatorBase {
std::vector<std::string> out_arg_names = Outputs("Out"); std::vector<std::string> out_arg_names = Outputs("Out");
std::vector<framework::LoDTensor> ins; std::vector<framework::LoDTensor> ins;
reader->ReadNext(&ins); reader->ReadNext(&ins);
if (ins.empty()) {
reader->ReInit();
reader->ReadNext(&ins);
PADDLE_ENFORCE(
!ins.empty(),
"Reader can not read the next data even it has been re-initialized.");
}
PADDLE_ENFORCE_EQ(ins.size(), out_arg_names.size()); PADDLE_ENFORCE_EQ(ins.size(), out_arg_names.size());
for (size_t i = 0; i < ins.size(); ++i) { for (size_t i = 0; i < ins.size(); ++i) {
auto* out = auto* out =
......
...@@ -111,7 +111,7 @@ class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase { ...@@ -111,7 +111,7 @@ class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase {
"When 'unsafe_mode' is false, invoking 'HasNext()' or " "When 'unsafe_mode' is false, invoking 'HasNext()' or "
"'ReInit()' is not allowed to avoid unexpected bugs in " "'ReInit()' is not allowed to avoid unexpected bugs in "
"multi-thread environment.") "multi-thread environment.")
.SetDefault(false); .SetDefault(true);
AddComment(R"DOC( AddComment(R"DOC(
CreateThreadedReader Operator CreateThreadedReader Operator
...@@ -134,6 +134,6 @@ class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase { ...@@ -134,6 +134,6 @@ class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase {
} // namespace paddle } // namespace paddle
namespace reader = paddle::operators::reader; namespace reader = paddle::operators::reader;
REGISTER_FILE_READER_OPERATOR(create_threaded_reader, REGISTER_DECORATED_READER_OPERATOR(create_threaded_reader,
reader::CreateThreadedReaderOp, reader::CreateThreadedReaderOp,
reader::CreateThreadedReaderOpMaker); reader::CreateThreadedReaderOpMaker);
...@@ -350,7 +350,7 @@ def open_recordio_file(filename, ...@@ -350,7 +350,7 @@ def open_recordio_file(filename,
main_prog_var = multi_pass(reader=main_prog_var, pass_num=pass_num) main_prog_var = multi_pass(reader=main_prog_var, pass_num=pass_num)
if for_parallel: if for_parallel:
main_prog_var = for_parallel(reader=main_prog_var) main_prog_var = parallelize(reader=main_prog_var)
return monkey_patch_reader_methods(main_prog_var) return monkey_patch_reader_methods(main_prog_var)
...@@ -435,12 +435,12 @@ def open_files(filenames, ...@@ -435,12 +435,12 @@ def open_files(filenames,
reader=main_prog_reader, pass_num=pass_num) reader=main_prog_reader, pass_num=pass_num)
if for_parallel: if for_parallel:
main_prog_reader = for_parallel(reader=main_prog_reader) main_prog_reader = parallelize(reader=main_prog_reader)
return monkey_patch_reader_methods(main_prog_reader) return monkey_patch_reader_methods(main_prog_reader)
def __create_unshared_decorated_reader__(op_type, reader, attrs={}): def __create_shared_decorated_reader__(op_type, reader, attrs):
var_name = unique_name(op_type) var_name = unique_name(op_type)
startup_blk = default_startup_program().current_block() startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=var_name) startup_var = startup_blk.create_var(name=var_name)
...@@ -456,7 +456,7 @@ def __create_unshared_decorated_reader__(op_type, reader, attrs={}): ...@@ -456,7 +456,7 @@ def __create_unshared_decorated_reader__(op_type, reader, attrs={}):
return monkey_patch_reader_methods(main_prog_var) return monkey_patch_reader_methods(main_prog_var)
def __create_shared_decorated_reader__(op_type, reader, attrs={}): def __create_unshared_decorated_reader__(op_type, reader, attrs):
new_reader_name = unique_name(op_type) new_reader_name = unique_name(op_type)
main_blk = default_main_program().current_block() main_blk = default_main_program().current_block()
new_reader = main_blk.create_var(name=new_reader_name) new_reader = main_blk.create_var(name=new_reader_name)
...@@ -488,8 +488,9 @@ def multi_pass(reader, pass_num): ...@@ -488,8 +488,9 @@ def multi_pass(reader, pass_num):
'create_multi_pass_reader', reader, {'pass_num': int(pass_num)}) 'create_multi_pass_reader', reader, {'pass_num': int(pass_num)})
def for_parallel(reader): def parallelize(reader):
return __create_shared_decorated_reader__('create_threaded_reader', reader) return __create_shared_decorated_reader__('create_threaded_reader', reader,
{})
def read_file(file_obj): def read_file(file_obj):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册