From b094110256c31da3ae002ceaefee5e367b9fcaec Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Tue, 2 Nov 2021 14:14:12 +0800 Subject: [PATCH] fix some bug, test=develop (#36888) --- .../framework/new_executor/interpretercore.cc | 7 +++-- .../new_executor/interpretercore_util.cc | 23 +++++++++++---- .../new_executor/new_executor_defs.h | 9 ++++++ .../operators/controlflow/fetch_v2_op.cc | 28 +++++++++++++++++-- 4 files changed, 57 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index 3ea8b8d309d..a8007c2f26a 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -241,13 +241,14 @@ void InterpreterCore::BuildInplace() { auto& outputs = instr.Outputs(); for (auto& pair : in_to_outs) { auto iter = inputs.find(pair.first); - if (iter != inputs.end()) { + if (iter != inputs.end() && !iter->second.empty()) { if (BuildInplaceCheckVarIsOnlyInput(iter->second[0])) { auto iterout = outputs.find(pair.second); - if (iterout != outputs.end()) { + if (iterout != outputs.end() && !iterout->second.empty()) { auto invar = global_scope_->Var(iter->second[0]); auto outvar = global_scope_->Var(iterout->second[0]); - if (invar && outvar) { + if (invar && outvar && invar->IsType() && + outvar->IsType()) { instr.AddInplace(invar, outvar); VLOG(3) << "inplace " << vec_instruction_[i].OpBase()->Type() << " " << global_scope_->GetNameById(iter->second[0]) diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.cc b/paddle/fluid/framework/new_executor/interpretercore_util.cc index a4443b08847..9de03a435ab 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.cc +++ b/paddle/fluid/framework/new_executor/interpretercore_util.cc @@ -142,8 +142,8 @@ void build_variable_scope(const framework::ProgramDesc& pdesc, if (nullptr == var_scope->FindVar(var_name)) { var_scope->AddVar(var_desc->Name(), var_desc); } else { - auto* var_desc = var_scope->VarDesc(var_name); - if (nullptr == var_desc) { + auto* var_desc_tmp = var_scope->VarDesc(var_name); + if (nullptr == var_desc_tmp) { VLOG(3) << "update var:" << var_name << " desc from nullptr into " << var_desc; var_scope->VarMetaInfo(var_name).vardesc_ = var_desc; @@ -206,9 +206,22 @@ void apply_device_guard(const OperatorBase* op_base, VLOG(3) << "Switch into CPUPlace by device_guard."; expected_kernel_key->place_ = platform::CPUPlace(); } else if (op_device.find("gpu") != std::string::npos && - platform::is_gpu_place(place)) { - VLOG(3) << "Switch into " << place << " by device_guard."; - expected_kernel_key->place_ = place; + (platform::is_gpu_place(place) || + platform::is_npu_place(place))) { + // when the Op that only has CPUKernel is assigned to GPU, the CPUKernel + // will be executed and a warning will be given at the same time. + if (op_base->SupportGPU()) { + expected_kernel_key->place_ = place; + } else if (op_base->SupportNPU()) { + expected_kernel_key->place_ = place; + } else { + expected_kernel_key->place_ = platform::CPUPlace(); + LOG_FIRST_N(WARNING, 1) + << "Op(" << op_base->Type() + << ") has no CUDA implementation. It will be assigned to CPUPlace."; + } + VLOG(3) << "Switch into " << expected_kernel_key->place_ + << " by device_guard."; } else { PADDLE_THROW( platform::errors::Fatal("Unsupported current place %s", op_device)); diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.h b/paddle/fluid/framework/new_executor/new_executor_defs.h index 58b6c924e23..d70243b93fe 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.h +++ b/paddle/fluid/framework/new_executor/new_executor_defs.h @@ -474,6 +474,15 @@ struct VariableMetaInfo { // TODO(zhiqiu): Maybe we need to add rwlock for VariableScope? class VariableScope : public ScopeBase { public: + VariableScope() { + // for @EMPTY@ variable + var_list_.push_back(nullptr); + name2id_[kEmptyVarName] = 0; + VariableMetaInfo info; + info.var_ref_count_ = 0; + info.vardesc_ = nullptr; + vec_meta_info_.push_back(info); + } Variable* FindVar(const std::string& name) const { auto it = name2id_.find(name); if (it != name2id_.end()) { diff --git a/paddle/fluid/operators/controlflow/fetch_v2_op.cc b/paddle/fluid/operators/controlflow/fetch_v2_op.cc index bf9874c02f6..0837caf9353 100644 --- a/paddle/fluid/operators/controlflow/fetch_v2_op.cc +++ b/paddle/fluid/operators/controlflow/fetch_v2_op.cc @@ -77,12 +77,35 @@ class FetchV2Op : public framework::OperatorWithKernel { framework::OpKernelType GetKernelTypeForVar( const std::string &var_name, const framework::Tensor &tensor, const framework::OpKernelType &expected_kernel_type) const override { + if (!tensor.IsInitialized()) { + return expected_kernel_type; + } return framework::OpKernelType(expected_kernel_type.data_type_, tensor.place(), tensor.layout()); } framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { + auto *fetch_var = ctx.InputVar("X"); + if (fetch_var == nullptr) { + return framework::OpKernelType(framework::proto::VarType::FP32, + platform::CPUPlace()); + } + + if (fetch_var->IsType()) { + auto &src_item = fetch_var->Get(); + if (!src_item.IsInitialized()) { + return framework::OpKernelType(framework::proto::VarType::FP32, + platform::CPUPlace()); + } + } else { + auto &src_item = fetch_var->Get(); + if (src_item.empty() || !src_item[0].IsInitialized()) { + return framework::OpKernelType(framework::proto::VarType::FP32, + platform::CPUPlace()); + } + } + return framework::OpKernelType( OperatorWithKernel::IndicateVarDataType(ctx, "X"), platform::CPUPlace()); @@ -127,6 +150,9 @@ class FetchV2Kernel { if (fetch_var->IsType()) { auto &src_item = fetch_var->Get(); + if (!src_item.IsInitialized()) { + return; + } auto *dst_item = &(BOOST_GET(framework::LoDTensor, fetch_list->at(col))); bool check_place = platform::is_cpu_place(src_item.place()) || platform::is_cuda_pinned_place(src_item.place()); @@ -173,9 +199,7 @@ class FetchV2OpProtoMaker : public framework::OpProtoAndCheckerMaker { .SetDefault(true); AddComment(R"DOC( FetchV2 Operator. - It should not be configured by users directly. - )DOC"); } }; -- GitLab