提交 ee178d5a 编写于 作者: J JiayiFeng

fix bugs

上级 fe48dfcb
......@@ -115,14 +115,12 @@ void ParallelExecutor::BCastParamsToGPUs(
for (auto &var : vars) {
auto *main_var = main_scope->FindVar(var);
if (!main_var->IsType<LoDTensor>()) {
if (main_var == nullptr || !main_var->IsType<LoDTensor>()) {
continue;
}
auto &main_tensor = main_var->Get<LoDTensor>();
auto &dims = main_tensor.dims();
if (paddle::platform::is_gpu_place(main_tensor.place())) {
size_t numel = main_tensor.numel();
ncclDataType_t data_type = platform::ToNCCLDataType(main_tensor.type());
......
......@@ -66,13 +66,6 @@ class ReadOp : public framework::OperatorBase {
std::vector<std::string> out_arg_names = Outputs("Out");
std::vector<framework::LoDTensor> ins;
reader->ReadNext(&ins);
if (ins.empty()) {
reader->ReInit();
reader->ReadNext(&ins);
PADDLE_ENFORCE(
!ins.empty(),
"Reader can not read the next data even it has been re-initialized.");
}
PADDLE_ENFORCE_EQ(ins.size(), out_arg_names.size());
for (size_t i = 0; i < ins.size(); ++i) {
auto* out =
......
......@@ -111,7 +111,7 @@ class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase {
"When 'unsafe_mode' is false, invoking 'HasNext()' or "
"'ReInit()' is not allowed to avoid unexpected bugs in "
"multi-thread environment.")
.SetDefault(false);
.SetDefault(true);
AddComment(R"DOC(
CreateThreadedReader Operator
......@@ -134,6 +134,6 @@ class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase {
} // namespace paddle
namespace reader = paddle::operators::reader;
REGISTER_FILE_READER_OPERATOR(create_threaded_reader,
reader::CreateThreadedReaderOp,
reader::CreateThreadedReaderOpMaker);
REGISTER_DECORATED_READER_OPERATOR(create_threaded_reader,
reader::CreateThreadedReaderOp,
reader::CreateThreadedReaderOpMaker);
......@@ -350,7 +350,7 @@ def open_recordio_file(filename,
main_prog_var = multi_pass(reader=main_prog_var, pass_num=pass_num)
if for_parallel:
main_prog_var = for_parallel(reader=main_prog_var)
main_prog_var = parallelize(reader=main_prog_var)
return monkey_patch_reader_methods(main_prog_var)
......@@ -435,12 +435,12 @@ def open_files(filenames,
reader=main_prog_reader, pass_num=pass_num)
if for_parallel:
main_prog_reader = for_parallel(reader=main_prog_reader)
main_prog_reader = parallelize(reader=main_prog_reader)
return monkey_patch_reader_methods(main_prog_reader)
def __create_unshared_decorated_reader__(op_type, reader, attrs={}):
def __create_shared_decorated_reader__(op_type, reader, attrs):
var_name = unique_name(op_type)
startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=var_name)
......@@ -456,7 +456,7 @@ def __create_unshared_decorated_reader__(op_type, reader, attrs={}):
return monkey_patch_reader_methods(main_prog_var)
def __create_shared_decorated_reader__(op_type, reader, attrs={}):
def __create_unshared_decorated_reader__(op_type, reader, attrs):
new_reader_name = unique_name(op_type)
main_blk = default_main_program().current_block()
new_reader = main_blk.create_var(name=new_reader_name)
......@@ -488,8 +488,9 @@ def multi_pass(reader, pass_num):
'create_multi_pass_reader', reader, {'pass_num': int(pass_num)})
def for_parallel(reader):
return __create_shared_decorated_reader__('create_threaded_reader', reader)
def parallelize(reader):
return __create_shared_decorated_reader__('create_threaded_reader', reader,
{})
def read_file(file_obj):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册