From ee178d5aebaa48f3434ec577ab1b450f4e3d7eab Mon Sep 17 00:00:00 2001 From: JiayiFeng Date: Tue, 10 Apr 2018 04:41:39 +0000 Subject: [PATCH] fix bugs --- paddle/fluid/framework/parallel_executor.cc | 4 +--- paddle/fluid/operators/read_op.cc | 7 ------- .../operators/reader/create_threaded_reader_op.cc | 8 ++++---- python/paddle/fluid/layers/io.py | 13 +++++++------ 4 files changed, 12 insertions(+), 20 deletions(-) diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 74945fb4f..1bb089c34 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -115,14 +115,12 @@ void ParallelExecutor::BCastParamsToGPUs( for (auto &var : vars) { auto *main_var = main_scope->FindVar(var); - if (!main_var->IsType()) { + if (main_var == nullptr || !main_var->IsType()) { continue; } auto &main_tensor = main_var->Get(); - auto &dims = main_tensor.dims(); - if (paddle::platform::is_gpu_place(main_tensor.place())) { size_t numel = main_tensor.numel(); ncclDataType_t data_type = platform::ToNCCLDataType(main_tensor.type()); diff --git a/paddle/fluid/operators/read_op.cc b/paddle/fluid/operators/read_op.cc index 2925b8a85..4496110cf 100644 --- a/paddle/fluid/operators/read_op.cc +++ b/paddle/fluid/operators/read_op.cc @@ -66,13 +66,6 @@ class ReadOp : public framework::OperatorBase { std::vector out_arg_names = Outputs("Out"); std::vector ins; reader->ReadNext(&ins); - if (ins.empty()) { - reader->ReInit(); - reader->ReadNext(&ins); - PADDLE_ENFORCE( - !ins.empty(), - "Reader can not read the next data even it has been re-initialized."); - } PADDLE_ENFORCE_EQ(ins.size(), out_arg_names.size()); for (size_t i = 0; i < ins.size(); ++i) { auto* out = diff --git a/paddle/fluid/operators/reader/create_threaded_reader_op.cc b/paddle/fluid/operators/reader/create_threaded_reader_op.cc index 854381e0e..7b10135af 100644 --- a/paddle/fluid/operators/reader/create_threaded_reader_op.cc +++ b/paddle/fluid/operators/reader/create_threaded_reader_op.cc @@ -111,7 +111,7 @@ class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase { "When 'unsafe_mode' is false, invoking 'HasNext()' or " "'ReInit()' is not allowed to avoid unexpected bugs in " "multi-thread environment.") - .SetDefault(false); + .SetDefault(true); AddComment(R"DOC( CreateThreadedReader Operator @@ -134,6 +134,6 @@ class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase { } // namespace paddle namespace reader = paddle::operators::reader; -REGISTER_FILE_READER_OPERATOR(create_threaded_reader, - reader::CreateThreadedReaderOp, - reader::CreateThreadedReaderOpMaker); +REGISTER_DECORATED_READER_OPERATOR(create_threaded_reader, + reader::CreateThreadedReaderOp, + reader::CreateThreadedReaderOpMaker); diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index 4901521db..d016ab900 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -350,7 +350,7 @@ def open_recordio_file(filename, main_prog_var = multi_pass(reader=main_prog_var, pass_num=pass_num) if for_parallel: - main_prog_var = for_parallel(reader=main_prog_var) + main_prog_var = parallelize(reader=main_prog_var) return monkey_patch_reader_methods(main_prog_var) @@ -435,12 +435,12 @@ def open_files(filenames, reader=main_prog_reader, pass_num=pass_num) if for_parallel: - main_prog_reader = for_parallel(reader=main_prog_reader) + main_prog_reader = parallelize(reader=main_prog_reader) return monkey_patch_reader_methods(main_prog_reader) -def __create_unshared_decorated_reader__(op_type, reader, attrs={}): +def __create_shared_decorated_reader__(op_type, reader, attrs): var_name = unique_name(op_type) startup_blk = default_startup_program().current_block() startup_var = startup_blk.create_var(name=var_name) @@ -456,7 +456,7 @@ def __create_unshared_decorated_reader__(op_type, reader, attrs={}): return monkey_patch_reader_methods(main_prog_var) -def __create_shared_decorated_reader__(op_type, reader, attrs={}): +def __create_unshared_decorated_reader__(op_type, reader, attrs): new_reader_name = unique_name(op_type) main_blk = default_main_program().current_block() new_reader = main_blk.create_var(name=new_reader_name) @@ -488,8 +488,9 @@ def multi_pass(reader, pass_num): 'create_multi_pass_reader', reader, {'pass_num': int(pass_num)}) -def for_parallel(reader): - return __create_shared_decorated_reader__('create_threaded_reader', reader) +def parallelize(reader): + return __create_shared_decorated_reader__('create_threaded_reader', reader, + {}) def read_file(file_obj): -- GitLab